Compare commits

...

27 Commits

Author SHA1 Message Date
Bruce MacDonald
da0190f8e4 Update quickstart.mdx 2026-02-11 12:16:31 -08:00
Bruce MacDonald
1f2594f50f Update quickstart.mdx 2026-02-11 12:15:02 -08:00
Bruce MacDonald
3fb0958e01 docs: update quickstart for tui 2026-02-11 12:08:55 -08:00
Parth Sareen
f08427c138 cmd: TUI UX improvements (#14198) 2026-02-11 10:18:41 -08:00
Maternion
2dbb000908 update context length format. 2026-02-10 17:06:05 -08:00
Maternion
c980e19995 Fix formatting of context length notes in documentation 2026-02-10 17:06:05 -08:00
Maternion
6162374ca9 Update context-length.mdx 2026-02-10 17:06:05 -08:00
Patrick Devine
44bdd9a2ef Add MLX runner with GLM4-MoE-Lite model support (#14185)
This change adds a new MLX-based runner, which includes:

  * Method-based MLX bindings
  * Subprocess-based MLX runner (x/mlxrunner)
  * KV cache with tree management
  * A basic sampler

The GLM4-MoE-Lite model has been ported to use the new bindings.

---------

Co-authored-by: Michael Yang <git@mxy.ng>
2026-02-10 14:57:57 -08:00
Michael
db493d6e5e docs: update broken links on FAQ and quick cleanup (#14194)
2026-02-10 16:52:20 -05:00
Bruce MacDonald
75695f16a5 docs: integration overview (#13831)
Group integrations into high-level types
2026-02-10 11:41:09 -08:00
Patrick Devine
a0407d07fa safetensors quantization for mlx (#14184)
This change includes:
  - changes to the safetensors metadata format
  - changes to the create command to properly create the blobs with the new format
  - changes to load the new format
  - fixes ollama show to properly show each tensor
2026-02-10 11:29:17 -08:00
Jeffrey Morgan
9ec733e527 cmd: make 'ollama login' and 'ollama logout' aliases for 'ollama signin' and 'ollama signout' respectively (#14144) 2026-02-09 19:12:42 -08:00
Parth Sareen
5ef04dab52 cmd: ollama launch pi (#14084) 2026-02-09 19:07:41 -08:00
Daniel Hiltgen
aea316f1e9 win: add curl-style install script (#14178)
This adds a new PowerShell install script suitable for running via

  irm https://ollama.com/install.ps1 | iex

If you download the script and run it with '-?', it reports basic usage
information, as well as usage examples for common customization
options. The script is signed as part of the release process
to ensure it can run on a typically configured Windows system.

This does not include doc updates - we can merge those after a release
ships to avoid user confusion.
2026-02-09 15:28:11 -08:00
Patrick Devine
235ba3df5c cmd: ollama menu and launch improvements (#14038) 2026-02-09 11:30:16 -08:00
Jeffrey Morgan
099a0f18ef build: fix Dockerfile mlx directory (#14131) 2026-02-06 17:08:53 -08:00
Richard Lyons
fff696ee31 docs: increased RAM requirement for parallelism 2026-02-06 15:49:39 -08:00
Jeffrey Morgan
2e3ce6eab3 anthropic: do not count image tokens for now (#14127) 2026-02-06 15:33:18 -08:00
Parth Sareen
9e2003f88a cmd/config: offer to pull missing models instead of erroring (#14113) 2026-02-06 10:19:47 -08:00
Parth Sareen
42e1d49fbe cmd: fix context limits for droid and add qwen3-coder-next ctx (#14112) 2026-02-05 22:29:53 -08:00
Michael Yang
814630ca60 Revert "move tokenizers to separate package (#13825)" (#14111) 2026-02-05 20:49:08 -08:00
Parth Sareen
87cf187774 cmd: set claude code env vars on launch (#14109)
Set ANTHROPIC_DEFAULT_OPUS_MODEL, ANTHROPIC_DEFAULT_SONNET_MODEL,
ANTHROPIC_DEFAULT_HAIKU_MODEL, and CLAUDE_CODE_SUBAGENT_MODEL when
launching Claude Code so all model tiers route through Ollama.
2026-02-05 19:04:53 -08:00
Michael Yang
6ddd8862cd chore: move x/mlxrunner into x/imagegen (#14100) 2026-02-05 18:25:56 -08:00
Michael Yang
f1373193dc move tokenizers to separate package (#13825) 2026-02-05 17:44:11 -08:00
Parth Sareen
8a4b77f9da cmd: set context limits for cloud models in opencode (#14107) 2026-02-05 16:36:46 -08:00
Parth Sareen
5f53fe7884 cmd: ollama launch improvements (#14099) 2026-02-05 15:08:17 -08:00
Bruce MacDonald
7ab4ca0e7f scripts: add macOS support to install.sh (#14060)
Allow installing Ollama on macOS directly from the command line. This is in line with other CLI tools and gives a more streamlined experience when the user specifically wants to use the CLI.
2026-02-05 14:59:01 -08:00
199 changed files with 25891 additions and 9692 deletions

View File

@@ -337,6 +337,7 @@ jobs:
name: bundles-windows
path: |
dist/*.zip
dist/*.ps1
dist/OllamaSetup.exe
linux-build:
@@ -514,6 +515,9 @@ jobs:
- name: Log dist contents
run: |
ls -l dist/
- name: Copy install scripts to dist
run: |
cp scripts/install.sh dist/install.sh
- name: Generate checksum file
run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist
@@ -536,7 +540,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
-for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
+for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

22
.github/workflows/test-install.yaml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: test-install
on:
pull_request:
paths:
- 'scripts/install.sh'
- '.github/workflows/test-install.yaml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Run install script
run: sh ./scripts/install.sh
env:
OLLAMA_NO_START: 1 # do not start app
- name: Verify ollama is available
run: ollama --version

View File

@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -147,7 +147,7 @@ ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
-COPY x/ml/backend/mlx x/ml/backend/mlx
+COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local

View File

@@ -518,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
-ID string
-Model string
-firstWrite bool
-contentIndex int
-inputTokens int
-outputTokens int
-thinkingStarted bool
-thinkingDone bool
-textStarted bool
-toolCallsSent map[string]bool
+ID string
+Model string
+firstWrite bool
+contentIndex int
+inputTokens int
+outputTokens int
+estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
+thinkingStarted bool
+thinkingDone bool
+textStarted bool
+toolCallsSent map[string]bool
}
-func NewStreamConverter(id, model string) *StreamConverter {
+func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
return &StreamConverter{
-ID: id,
-Model: model,
-firstWrite: true,
-toolCallsSent: make(map[string]bool),
+ID: id,
+Model: model,
+firstWrite: true,
+estimatedInputTokens: estimatedInputTokens,
+toolCallsSent: make(map[string]bool),
}
}
@@ -551,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
if c.firstWrite {
c.firstWrite = false
// Use actual metrics if available, otherwise use estimate
c.inputTokens = r.Metrics.PromptEvalCount
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
c.inputTokens = c.estimatedInputTokens
}
events = append(events, StreamEvent{
Event: "message_start",
@@ -779,3 +785,117 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct {
Model string `json:"model"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
}
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
func EstimateInputTokens(req MessagesRequest) int {
return estimateTokens(CountTokensRequest{
Model: req.Model,
Messages: req.Messages,
System: req.System,
Tools: req.Tools,
Thinking: req.Thinking,
})
}
// CountTokensResponse represents an Anthropic count_tokens response
type CountTokensResponse struct {
InputTokens int `json:"input_tokens"`
}
// estimateTokens returns a rough estimate of tokens (len/4).
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
func estimateTokens(req CountTokensRequest) int {
var totalLen int
// Count system prompt
if req.System != nil {
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages {
// Count role (always present)
totalLen += len(msg.Role)
// Count content
contentLen := countAnyContent(msg.Content)
totalLen += contentLen
}
for _, tool := range req.Tools {
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
}
// Return len/4 as rough token estimate, minimum 1 if there's any content
tokens := totalLen / 4
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
tokens = 1
}
return tokens
}
func countAnyContent(content any) int {
if content == nil {
return 0
}
switch c := content.(type) {
case string:
return len(c)
case []any:
total := 0
for _, block := range c {
total += countContentBlock(block)
}
return total
default:
if data, err := json.Marshal(content); err == nil {
return len(data)
}
return 0
}
}
func countContentBlock(block any) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0
blockType, _ := blockMap["type"].(string)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
}
if thinking, ok := blockMap["thinking"].(string); ok {
total += len(thinking)
}
if blockType == "tool_use" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
return total
}
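A minimal self-contained sketch of the len/4 heuristic implemented above; roughTokens and the sample strings are illustrative only, not part of this diff:

package main

import "fmt"

// roughTokens mirrors the estimateTokens heuristic above: total character
// length divided by four, with a minimum of one token when any content exists.
func roughTokens(role, content string) int {
	totalLen := len(role) + len(content)
	tokens := totalLen / 4
	if tokens == 0 && totalLen > 0 {
		tokens = 1
	}
	return tokens
}

func main() {
	// "user" (4 chars) + "Hello, world!" (13 chars) = 17 chars -> 4 tokens
	fmt.Println(roughTokens("user", "Hello, world!")) // 4
	// Very short content still counts as at least one token
	fmt.Println(roughTokens("user", "Hi")) // 1
}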

View File

@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
}
}
-// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
-// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
}
func TestStreamConverter_Basic(t *testing.T) {
-conv := NewStreamConverter("msg_123", "test-model")
+conv := NewStreamConverter("msg_123", "test-model", 0)
// First chunk
resp1 := api.ChatResponse{
@@ -678,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
-conv := NewStreamConverter("msg_123", "test-model")
+conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -731,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
-conv := NewStreamConverter("msg_123", "test-model")
+conv := NewStreamConverter("msg_123", "test-model", 0)
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
@@ -778,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
-conv := NewStreamConverter("msg_123", "test-model")
+conv := NewStreamConverter("msg_123", "test-model", 0)
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
@@ -842,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
}
}
-// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
-// are serialized in JSON output. The Anthropic SDK requires these fields to be present
-// (even when empty) in content_block_start events to properly accumulate streaming deltas.
-// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
@@ -899,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
}
}
-// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
-// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
-conv := NewStreamConverter("msg_123", "test-model")
+conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -937,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
-conv := NewStreamConverter("msg_123", "test-model")
+conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -969,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
}
})
}
func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"},
},
}
tokens := estimateTokens(req)
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
if tokens < 1 {
t.Errorf("expected at least 1 token, got %d", tokens)
}
// Sanity check: shouldn't be wildly off
if tokens > 10 {
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
}
}
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
tokens := estimateTokens(req)
// System prompt adds to count
if tokens < 5 {
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
}
}
func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"},
},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get the current weather for a location",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
},
},
}
tokens := estimateTokens(req)
// Tools add significant content
if tokens < 10 {
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
}
}
func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this carefully...",
},
map[string]any{
"type": "text",
"text": "Here is my response.",
},
},
},
},
}
tokens := estimateTokens(req)
// Thinking content should be counted
if tokens < 10 {
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
}
}
func TestEstimateTokens_EmptyContent(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{},
}
tokens := estimateTokens(req)
if tokens != 0 {
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
}
}

View File

@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
}
return &resp, nil
}
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}
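A minimal sketch of calling the experimental alias endpoints added above; the target model name is a placeholder:

package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// Route any model name starting with "claude-sonnet-" to a local model.
	err = client.SetAliasExperimental(ctx, &api.AliasRequest{
		Alias:          "claude-sonnet-", // prefix, matched against incoming model names
		Target:         "qwen3:8b",       // placeholder target model
		PrefixMatching: true,
	})
	if err != nil {
		log.Fatal(err)
	}

	// Remove the alias again.
	if err := client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: "claude-sonnet-"}); err != nil {
		log.Fatal(err)
	}
}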

13
cmd/background_unix.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
// Setpgid prevents the server from being killed when the parent process exits.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}

12
cmd/background_windows.go Normal file
View File

@@ -0,0 +1,12 @@
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
CreationFlags: 0x08000000,
HideWindow: true,
}
}
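A sketch of how the platform-specific attributes above are typically applied when spawning the background server; startDetachedServer is a hypothetical name, and the wiring mirrors the ensureServerRunning change further down:

package cmd

import (
	"os"
	"os/exec"
)

// startDetachedServer launches "ollama serve" so it outlives the CLI process.
// backgroundServerSysProcAttr comes from the platform-specific files above.
func startDetachedServer() error {
	exe, err := os.Executable()
	if err != nil {
		return err
	}
	cmd := exec.Command(exe, "serve")
	cmd.Env = os.Environ()
	cmd.SysProcAttr = backgroundServerSysProcAttr()
	return cmd.Start() // do not Wait; the server keeps running in the background
}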

View File

@@ -15,6 +15,7 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -37,6 +38,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -53,6 +55,49 @@ import (
"github.com/ollama/ollama/x/imagegen"
)
func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
}
result, err := tui.SelectSingle(title, tuiItems)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
return result, err
}
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
}
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
}
return result, err
}
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
userName, err := tui.RunSignIn(modelName, signInURL)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
return userName, err
}
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
ok, err := tui.RunConfirm(prompt)
if errors.Is(err, tui.ErrCancelled) {
return false, config.ErrCancelled
}
return ok, err
}
}
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
@@ -1763,7 +1808,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
-return fmt.Errorf("ollama server not responding - %w", err)
+return err
}
}
return nil
@@ -1804,6 +1849,197 @@ Environment Variables:
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
}
// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
func ensureServerRunning(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Check if server is already running
if err := client.Heartbeat(ctx); err == nil {
return nil // server is already running
}
// Server not running, start it in the background
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("could not find executable: %w", err)
}
serverCmd := exec.CommandContext(ctx, exe, "serve")
serverCmd.Env = os.Environ()
serverCmd.SysProcAttr = backgroundServerSysProcAttr()
if err := serverCmd.Start(); err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
// Wait for the server to be ready
for {
time.Sleep(500 * time.Millisecond)
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
// runInteractiveTUI runs the main interactive TUI menu.
func runInteractiveTUI(cmd *cobra.Command) {
// Ensure the server is running before showing the TUI
if err := ensureServerRunning(cmd.Context()); err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
return
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem) (string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
}
result, err := tui.SelectSingle(title, tuiItems)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
return result, err
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
}
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
}
return result, err
}
for {
result, err := tui.Run()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
runModel := func(modelName string) {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
if errors.Is(err, config.ErrCancelled) {
return
}
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
_ = config.SetLastModel(modelName)
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
return
}
if err := generateInteractive(cmd, opts); err != nil {
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
}
}
launchIntegration := func(name string) bool {
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
return false // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
return true
}
}
if err := config.LaunchIntegration(name); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
}
return true
}
switch result.Selection {
case tui.SelectionNone:
// User quit
return
case tui.SelectionRunModel:
_ = config.SetLastSelection("run")
if modelName := config.LastModel(); modelName != "" {
runModel(modelName)
} else {
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
runModel(modelName)
}
case tui.SelectionChangeRunModel:
_ = config.SetLastSelection("run")
// Use model from modal if selected, otherwise show picker
modelName := result.Model
if modelName == "" {
var err error
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
}
runModel(modelName)
case tui.SelectionIntegration:
_ = config.SetLastSelection(result.Integration)
if !launchIntegration(result.Integration) {
continue // Return to main menu
}
case tui.SelectionChangeIntegration:
_ = config.SetLastSelection(result.Integration)
// Use model from modal if selected, otherwise show picker
if result.Model != "" {
// Model already selected from modal - save and launch
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegration(result.Integration); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
}
}
}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
@@ -1826,11 +2062,13 @@ func NewCLI() *cobra.Command {
return
}
cmd.Print(cmd.UsageString())
runInteractiveTUI(cmd)
},
}
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
createCmd := &cobra.Command{
Use: "create MODEL",
@@ -1934,6 +2172,15 @@ func NewCLI() *cobra.Command {
RunE: SigninHandler,
}
loginCmd := &cobra.Command{
Use: "login",
Short: "Sign in to ollama.com",
Hidden: true,
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SigninHandler,
}
signoutCmd := &cobra.Command{
Use: "signout",
Short: "Sign out from ollama.com",
@@ -1942,6 +2189,15 @@ func NewCLI() *cobra.Command {
RunE: SignoutHandler,
}
logoutCmd := &cobra.Command{
Use: "logout",
Short: "Sign out from ollama.com",
Hidden: true,
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SignoutHandler,
}
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
@@ -2038,13 +2294,15 @@ func NewCLI() *cobra.Command {
pullCmd,
pushCmd,
signinCmd,
loginCmd,
signoutCmd,
logoutCmd,
listCmd,
psCmd,
copyCmd,
deleteCmd,
runnerCmd,
-config.LaunchCmd(checkServerHeartbeat),
+config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)
return rootCmd
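The readiness loop in ensureServerRunning above polls every 500ms with no deadline; a self-contained sketch of the same heartbeat-polling pattern bounded by a context timeout (waitForServer is hypothetical, not part of this change):

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

// waitForServer polls the heartbeat until the server responds or ctx expires.
func waitForServer(ctx context.Context, client *api.Client) error {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return fmt.Errorf("server did not become ready: %w", ctx.Err())
		case <-ticker.C:
			if err := client.Heartbeat(ctx); err == nil {
				return nil
			}
		}
	}
}

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := waitForServer(ctx, client); err != nil {
		panic(err)
	}
	fmt.Println("server is ready")
}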

View File

@@ -1,18 +1,23 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
-// Claude implements Runner for Claude Code integration
+// Claude implements Runner and AliasConfigurer for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
@@ -53,10 +58,135 @@ func (c *Claude) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
-cmd.Env = append(os.Environ(),
+env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
primary := model
fast := model
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
if p := cfg.Aliases["primary"]; p != "" {
primary = p
}
if f := cfg.Aliases["fast"]; f != "" {
fast = f
}
}
return []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
}
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := DefaultSingleSelector("Select model:", items)
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}
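SetAliases above assumes the server resolves incoming model names by prefix; a small self-contained sketch of those assumed resolution semantics (the alias targets are placeholders, and the real matching lives behind the experimental aliases API):

package main

import (
	"fmt"
	"strings"
)

// resolveAlias illustrates the prefix matching assumed by SetAliases: an
// incoming model name like "claude-sonnet-4-5" is routed to whatever target
// the matching prefix alias points at.
func resolveAlias(aliases map[string]string, model string) string {
	for prefix, target := range aliases {
		if strings.HasPrefix(model, prefix) {
			return target
		}
	}
	return model // no alias matched; use the name as-is
}

func main() {
	aliases := map[string]string{
		"claude-sonnet-": "kimi-k2.5:cloud", // primary
		"claude-haiku-":  "kimi-k2.5:cloud", // fast
	}
	fmt.Println(resolveAlias(aliases, "claude-sonnet-4-5")) // kimi-k2.5:cloud
	fmt.Println(resolveAlias(aliases, "llama3.2"))          // llama3.2
}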

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
@@ -103,3 +104,95 @@ func TestClaudeArgs(t *testing.T) {
})
}
}
func TestClaudeModelEnvVars(t *testing.T) {
c := &Claude{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses fast alias for haiku", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
})
got := envMap(c.modelEnvVars("llama3.2:70b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("alias primary overrides model param", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"saved-model"})
saveAliases("claude", map[string]string{"primary": "saved-model"})
got := envMap(c.modelEnvVars("different-model"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
})
}

View File

@@ -3,21 +3,26 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
)
type integration struct {
-Models []string `json:"models"`
+Models []string `json:"models"`
+Aliases map[string]string `json:"aliases,omitempty"`
}
type config struct {
-Integrations map[string]*integration `json:"integrations"`
+Integrations map[string]*integration `json:"integrations"`
+LastModel string `json:"last_model,omitempty"`
+LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
}
func configPath() (string, error) {
@@ -53,7 +58,6 @@ func migrateConfig() (bool, error) {
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
return false, nil
}
@@ -72,7 +76,6 @@ func migrateConfig() (bool, error) {
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
slog.Info("migrated config", "from", oldPath, "to", newPath)
return true, nil
}
@@ -133,13 +136,89 @@ func saveIntegration(appName string, models []string) error {
return err
}
-cfg.Integrations[strings.ToLower(appName)] = &integration{
-Models: models,
+key := strings.ToLower(appName)
+existing := cfg.Integrations[key]
+var aliases map[string]string
+if existing != nil && existing.Aliases != nil {
+aliases = existing.Aliases
+}
+cfg.Integrations[key] = &integration{
+Models: models,
+Aliases: aliases,
}
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
return ""
}
return ic.Models[0]
}
// LastModel returns the last model that was run, or empty string if none.
func LastModel() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastModel
}
// SetLastModel saves the last model that was run.
func SetLastModel(model string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastModel = model
return save(cfg)
}
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
func LastSelection() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastSelection
}
// SetLastSelection saves the last menu selection ("run" or integration name).
func SetLastSelection(selection string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastSelection = selection
return save(cfg)
}
// ModelExists checks if a model exists on the Ollama server.
func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
models, err := client.List(ctx)
if err != nil {
return false
}
for _, m := range models.Models {
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
return true
}
}
return false
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
@@ -154,6 +233,29 @@ func loadIntegration(appName string) (*integration, error) {
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
// Replace aliases entirely (not merge) so deletions are persisted
existing.Aliases = aliases
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {

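For orientation, a sketch that marshals sample values through the structs above to show the implied config.json shape; model names are placeholders:

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the config/integration structs above.
type integration struct {
	Models  []string          `json:"models"`
	Aliases map[string]string `json:"aliases,omitempty"`
}

type config struct {
	Integrations  map[string]*integration `json:"integrations"`
	LastModel     string                  `json:"last_model,omitempty"`
	LastSelection string                  `json:"last_selection,omitempty"`
}

func main() {
	cfg := config{
		Integrations: map[string]*integration{
			"claude": {
				Models:  []string{"kimi-k2.5:cloud"},
				Aliases: map[string]string{"primary": "kimi-k2.5:cloud", "fast": "kimi-k2.5:cloud"},
			},
		},
		LastModel:     "kimi-k2.5:cloud",
		LastSelection: "claude",
	}
	out, _ := json.MarshalIndent(cfg, "", "  ")
	fmt.Println(string(out))
}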
View File

@@ -0,0 +1,677 @@
package config
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
)
func TestSetAliases_CloudModel(t *testing.T) {
// Test the SetAliases logic by checking the alias map behavior
aliases := map[string]string{
"primary": "kimi-k2.5:cloud",
"fast": "kimi-k2.5:cloud",
}
// Verify fast is set (cloud model behavior)
if aliases["fast"] == "" {
t.Error("cloud model should have fast alias set")
}
if aliases["fast"] != aliases["primary"] {
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
}
}
func TestSetAliases_LocalModel(t *testing.T) {
aliases := map[string]string{
"primary": "llama3.2:latest",
}
// Simulate local model behavior: fast should be empty
delete(aliases, "fast")
if aliases["fast"] != "" {
t.Error("local model should have empty fast alias")
}
}
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save with both primary and fast
initial := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := saveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
}
// Now save without fast (simulating switch to local model)
updated := map[string]string{
"primary": "local-model",
// fast intentionally missing
}
if err := saveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
}
}
func TestSaveAliases_PreservesModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save integration with models
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
t.Errorf("models should be preserved, got %v", loaded.Models)
}
}
// TestSaveAliases_EmptyMap clears all aliases
func TestSaveAliases_EmptyMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) != 0 {
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_NilMap handles nil gracefully
func TestSaveAliases_NilMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) > 0 {
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
}
func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model1" {
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
}
// Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["primary"] != "model2" {
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
}
}
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
func TestSaveAliases_CreatesIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := loadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigureAliases_AliasMap(t *testing.T) {
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
aliases := make(map[string]string)
aliases["primary"] = "cloud-model"
// Simulate cloud model behavior
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
}
})
t.Run("cloud model preserves custom fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "cloud-model",
"fast": "custom-fast-model",
}
// Simulate cloud model behavior - should preserve existing fast
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "custom-fast-model" {
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
}
})
t.Run("local model clears fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
"fast": "should-be-cleared",
}
// Simulate local model behavior
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
}
})
t.Run("switching cloud to local clears fast", func(t *testing.T) {
// Start with cloud config
aliases := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
// Switch to local
aliases["primary"] = "local-model"
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
}
if aliases["primary"] != "local-model" {
t.Errorf("primary should be updated, got %q", aliases["primary"])
}
})
t.Run("switching local to cloud sets fast", func(t *testing.T) {
// Start with local config (no fast)
aliases := map[string]string{
"primary": "local-model",
}
// Switch to cloud
aliases["primary"] = "cloud-model"
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
}
})
}
func TestSetAliases_PrefixMapping(t *testing.T) {
// This tests the expected mapping without needing a real client
aliases := map[string]string{
"primary": "my-cloud-model",
"fast": "my-fast-model",
}
expectedMappings := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
t.Errorf("claude-sonnet- should map to primary")
}
if expectedMappings["claude-haiku-"] != "my-fast-model" {
t.Errorf("claude-haiku- should map to fast")
}
}
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
// fast is empty/missing - indicates local model
}
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
// Verify the logic: when fast is empty, we should delete
if aliases["fast"] != "" {
t.Error("fast should be empty for local model")
}
// Verify we have the right prefixes to delete
if len(prefixesToDelete) != 2 {
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
}
}
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server fails, config should NOT be saved
serverErr := errors.New("server unavailable")
if serverErr == nil {
t.Error("config should NOT be saved when server fails")
}
}
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server succeeds, config should be saved
var serverErr error
if serverErr != nil {
t.Fatal("server should succeed")
}
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Write config with extra fields
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
os.MkdirAll(filepath.Dir(configPath), 0o755)
// Note: Our config struct only has Integrations, so top-level unknown fields
// won't be preserved by our current implementation. This test documents that.
initialConfig := `{
"integrations": {
"claude": {
"models": ["model1"],
"aliases": {"primary": "model1"},
"unknownField": "should be lost"
}
},
"topLevelUnknown": "will be lost"
}`
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Read raw file to check
data, _ := os.ReadFile(configPath)
content := string(data)
// models should be preserved
if !contains(content, "model1") {
t.Error("models should be preserved")
}
// primary should be updated
if !contains(content, "model2") {
t.Error("primary should be updated to model2")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
model string
}{
{"simple", "llama3.2"},
{"with tag", "llama3.2:latest"},
{"with cloud tag", "kimi-k2.5:cloud"},
{"with namespace", "library/llama3.2"},
{"with dots", "glm-4.7-flash"},
{"with numbers", "qwen3:8b"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != tc.model {
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
}
})
}
}
func TestSwitchingScenarios(t *testing.T) {
t.Run("cloud to local removes fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
// Switch to local (no fast)
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
}
})
t.Run("local to cloud adds fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial local config
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
})
t.Run("cloud to different cloud updates both", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
t.Fatal(err)
}
// Switch to different cloud
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
if loaded.Aliases["fast"] != "cloud-model-2" {
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
}
})
}
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
// Save integration with same model (this is the pattern we use)
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
})
t.Run("out of sync config is detectable", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
t.Error("expected out-of-sync state for this test")
}
// The fix: when updating aliases, also update models
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ = loadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
}
})
t.Run("updating primary alias updates models too", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial state
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}
if loaded.Aliases["primary"] != "updated-model" {
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
}
})
}

View File

@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})

View File

@@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/json"
"fmt"
"os"
@@ -8,6 +9,7 @@ import (
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
@@ -112,9 +114,17 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}
}
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
@@ -122,7 +132,7 @@ func (d *Droid) Edit(models []string) error {
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,
MaxOutputTokens: maxOutput,
SupportsImages: false,
ID: modelID,
Index: i,

View File

@@ -1251,6 +1251,55 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
}
}
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
if err := d.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
models := settings["customModels"].([]any)
entry := models[0].(map[string]any)
if entry["maxOutputTokens"] != float64(64000) {
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
}
}
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// :cloud suffix stripping must also work since that's how users specify them.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
}
})
}
}
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()

View File

@@ -39,6 +39,15 @@ type Editor interface {
Models() []string
}
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}
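// Illustrative sketch: the minimal shape an integration needs to satisfy
// AliasConfigurer. The type name exampleIntegration is hypothetical; in this
// change only Claude provides a real implementation.
type exampleIntegration struct{}

func (e *exampleIntegration) ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) {
	// A real implementation would prompt the user; here the primary alias is
	// simply pointed at the chosen model.
	return map[string]string{"primary": primaryModel}, true, nil
}

func (e *exampleIntegration) SetAliases(ctx context.Context, aliases map[string]string) error {
	return nil // a real implementation syncs the aliases to the server
}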
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
@@ -48,30 +57,268 @@ var integrations = map[string]Runner{
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
"pi": &Pi{},
}
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: position in this list determines rank within each recommendation tier.
var recommendedModels = []selectItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
var recommendedModels = []ModelItem{
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
{Name: "glm-4.7:cloud", Description: "Reasoning and code generation", Recommended: true},
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
{Name: "qwen3:8b", Description: "Efficient all-purpose assistant", Recommended: true},
}
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
var recommendedVRAM = map[string]string{
"glm-4.7-flash": "~25GB",
"qwen3:8b": "~11GB",
}
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
// integrationInstallHints maps integration names to install URLs.
var integrationInstallHints = map[string]string{
"claude": "https://code.claude.com/docs/en/quickstart",
"openclaw": "https://docs.openclaw.ai",
"codex": "https://developers.openai.com/codex/cli/",
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
"opencode": "https://opencode.ai",
}
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
func hyperlink(url, text string) string {
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
}
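// For example, hyperlink("https://ollama.com", "ollama.com") yields
// "\033]8;;https://ollama.com\033\\ollama.com\033]8;;\033\\", which supporting
// terminals render as the clickable text "ollama.com".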
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
Name string // registry key, e.g. "claude"
DisplayName string // human-readable, e.g. "Claude Code"
Description string // short description, e.g. "Anthropic's agentic coding tool"
}
// integrationDescriptions maps integration names to short descriptions.
var integrationDescriptions = map[string]string{
"claude": "Anthropic's coding tool with subagents",
"codex": "OpenAI's open-source coding agent",
"openclaw": "Personal AI with 100+ skills",
"droid": "Factory's coding agent across terminal and IDEs",
"opencode": "Anomaly's open-source coding agent",
}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
func ListIntegrationInfos() []IntegrationInfo {
var result []IntegrationInfo
for name, r := range integrations {
if integrationAliases[name] {
continue
}
result = append(result, IntegrationInfo{
Name: name,
DisplayName: r.String(),
Description: integrationDescriptions[name],
})
}
slices.SortFunc(result, func(a, b IntegrationInfo) int {
return strings.Compare(a.Name, b.Name)
})
return result
}
// IntegrationInstallHint returns a user-friendly install hint for the given integration,
// or an empty string if none is available. The URL is wrapped in an OSC 8 hyperlink
// so it is cmd+clickable in supported terminals.
func IntegrationInstallHint(name string) string {
url := integrationInstallHints[name]
if url == "" {
return ""
}
return "Install from " + hyperlink(url, url)
}
// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
switch name {
case "claude":
c := &Claude{}
_, err := c.findPath()
return err == nil
case "openclaw":
if _, err := exec.LookPath("openclaw"); err == nil {
return true
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return true
}
return false
case "codex":
_, err := exec.LookPath("codex")
return err == nil
case "droid":
_, err := exec.LookPath("droid")
return err == nil
case "opencode":
_, err := exec.LookPath("opencode")
return err == nil
default:
return true // Assume installed for unknown integrations
}
}
// SelectModel lets the user select a model to run.
// ModelItem represents a model for selection.
type ModelItem struct {
Name string
Description string
Recommended bool
}
// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
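// Illustrative sketch: a trivial non-interactive SingleSelector, for example
// for tests, that always picks the first item. The variable name
// firstItemSelector is hypothetical; it could be passed to
// SelectModelWithSelector below in place of DefaultSingleSelector.
var firstItemSelector SingleSelector = func(title string, items []ModelItem) (string, error) {
	if len(items) == 0 {
		return "", fmt.Errorf("no items to select from")
	}
	return items[0].Name, nil
}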
// SelectModelWithSelector prompts the user to select a model using the provided selector.
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return "", err
}
models, err := client.List(ctx)
if err != nil {
return "", err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
lastModel := LastModel()
var preChecked []string
if lastModel != "" {
preChecked = []string{lastModel}
}
items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
if len(items) == 0 {
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
selected, err := selector("Select model to run:", items)
if err != nil {
return "", err
}
// If the selected model isn't installed, pull it first
if !existingModels[selected] {
msg := fmt.Sprintf("Download %s?", selected)
if ok, err := confirmPrompt(msg); err != nil {
return "", err
} else if !ok {
return "", errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, selected); err != nil {
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
}
}
// If it's a cloud model, ensure user is signed in
if cloudModels[selected] {
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return "", err
}
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
if err != nil || !yes {
return "", fmt.Errorf("%s requires sign in", selected)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return "", ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func SelectModel(ctx context.Context) (string, error) {
return SelectModelWithSelector(ctx, DefaultSingleSelector)
}
// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector
// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector
// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
func selectIntegration() (string, error) {
if DefaultSingleSelector == nil {
return "", fmt.Errorf("no selector configured")
}
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
var items []ModelItem
for _, name := range names {
if integrationAliases[name] {
continue
@@ -81,14 +328,14 @@ func selectIntegration() (string, error) {
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
}
items = append(items, selectItem{Name: name, Description: description})
items = append(items, ModelItem{Name: name, Description: description})
}
return selectPrompt("Select integration:", items)
return DefaultSingleSelector("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
@@ -124,12 +371,16 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
var selected []string
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
prompt := fmt.Sprintf("Select model for %s:", r)
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := single(prompt, items)
if err != nil {
return nil, err
}
@@ -157,73 +408,157 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
}
}
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
return nil, err
}
return selected, nil
}
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, model); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
// TODO(parthsareen): pull this out to tui package
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullModel(ctx, client, model)
}
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil, nil, nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, nil, nil, nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{
Name: m.Name,
Remote: m.RemoteModel != "",
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
return items, existingModels, cloudModels, client, nil
}
func OpenBrowser(url string) {
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", url).Start()
case "linux":
_ = exec.Command("xdg-open", url).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
}
}
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		return nil
	}
	// ensure user is signed in
	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		return nil
	}
	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		return err
	}
	modelList := strings.Join(selectedCloudModels, ", ")
	if DefaultSignIn != nil {
		_, err := DefaultSignIn(modelList, aErr.SigninURL)
		if err != nil {
			return fmt.Errorf("%s requires sign in", modelList)
		}
		return nil
	}
	// Fallback: plain text sign-in flow
	yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if err != nil || !yes {
		return fmt.Errorf("%s requires sign in", modelList)
	}
	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
	OpenBrowser(aErr.SigninURL)
	// TODO(parthsareen): extract into auth package for cmd
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
			// poll every 10th frame (~2 seconds)
			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
				}
			}
		}
	}
}
// selectModels lets the user select models for an integration using default selectors.
func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selectModelsWithSelectors(ctx, name, current, DefaultSingleSelector, DefaultMultiSelector)
}
func runIntegration(name, modelName string, args []string) error {
@@ -231,19 +566,151 @@ func runIntegration(name, modelName string, args []string) error {
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, args)
}
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
aliases := make(map[string]string)
for k, v := range existing {
aliases[k] = v
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
} else {
delete(aliases, "fast")
}
if err := ac.SetAliases(ctx, aliases); err != nil {
return err
}
return saveAliases(name, aliases)
}
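// Illustrative sketch: the alias derivation inside syncAliases, isolated for
// clarity. The name deriveAliases and the isCloud parameter are hypothetical
// stand-ins for the client-backed isCloudModel check.
func deriveAliases(existing map[string]string, model string, isCloud func(string) bool) map[string]string {
	aliases := make(map[string]string)
	for k, v := range existing {
		aliases[k] = v
	}
	aliases["primary"] = model
	if isCloud(model) {
		// a missing or local "fast" alias is pointed at the cloud primary
		if aliases["fast"] == "" || !isCloud(aliases["fast"]) {
			aliases["fast"] = model
		}
	} else {
		// a local primary drops the "fast" tier entirely
		delete(aliases, "fast")
	}
	return aliases
}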
// LaunchIntegration launches the named integration using saved config or prompts for setup.
func LaunchIntegration(name string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// Try to use saved config
if ic, err := loadIntegration(name); err == nil && len(ic.Models) > 0 {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := ShowOrPull(context.Background(), client, ic.Models[0]); err != nil {
return err
}
return runIntegration(name, ic.Models[0], nil)
}
// No saved config - prompt user to run setup
return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
}
// LaunchIntegrationWithModel launches the named integration with the specified model.
func LaunchIntegrationWithModel(name, modelName string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := ShowOrPull(context.Background(), client, modelName); err != nil {
return err
}
return runIntegration(name, modelName, nil)
}
// SaveIntegrationModel saves the model for an integration.
func SaveIntegrationModel(name, modelName string) error {
// Load existing models and prepend the new one
var models []string
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
models = existing.Models
// Remove the model if it already exists
for i, m := range models {
if m == modelName {
models = append(models[:i], models[i+1:]...)
break
}
}
}
// Prepend the new model
models = append([]string{modelName}, models...)
return saveIntegration(name, models)
}
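// For example: with saved models ["model-a", "model-b"], calling
// SaveIntegrationModel(name, "model-b") stores ["model-b", "model-a"], so the
// most recently launched model is first and never duplicated.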
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
if errors.Is(err, errCancelled) {
return errCancelled
}
if err != nil {
return err
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
} else {
fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
}
return nil
}
// ConfigureIntegration allows the user to select/change the model for an integration.
func ConfigureIntegration(ctx context.Context, name string) error {
return ConfigureIntegrationWithSelectors(ctx, name, DefaultSingleSelector, DefaultMultiSelector)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
// The runTUI callback is called when no arguments are provided (alias for main TUI).
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Short: "Launch the Ollama menu or an integration",
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
@@ -262,6 +729,12 @@ Examples:
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// No args and no flags - show the full TUI (same as bare 'ollama')
if len(args) == 0 && modelFlag == "" && !configFlag {
runTUI(cmd)
return nil
}
// Extract integration name and args to pass through using -- separator
var name string
var passArgs []string
@@ -302,9 +775,102 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], passArgs)
// Handle AliasConfigurer integrations (currently claude)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
model = ""
}
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
return err
}
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
@@ -318,6 +884,8 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
@@ -380,20 +948,23 @@ Examples:
}
type modelInfo struct {
Name string
Remote bool
Name string
Remote bool
ToolCapable bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
recDesc := make(map[string]string)
for _, rec := range recommendedModels {
recommended[rec.Name] = true
recDesc[rec.Name] = rec.Description
}
for _, m := range existing {
@@ -406,10 +977,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := selectItem{Name: displayName}
if recommended[displayName] {
item.Description = "recommended"
}
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
items = append(items, item)
}
@@ -418,7 +986,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
continue
}
items = append(items, rec)
if isCloudModel(rec.Name) {
if strings.HasSuffix(rec.Name, ":cloud") {
cloudModels[rec.Name] = true
}
}
@@ -446,31 +1014,76 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
	for i := range items {
		if !existingModels[items[i].Name] {
			notInstalled[items[i].Name] = true
			var parts []string
			if items[i].Description != "" {
				parts = append(parts, items[i].Description)
			}
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "install?")
			items[i].Description = strings.Join(parts, ", ")
		}
	}
// Build a recommended rank map to preserve ordering within tiers.
recRank := make(map[string]int)
for i, rec := range recommendedModels {
recRank[rec.Name] = i + 1 // 1-indexed; 0 means not recommended
}
onlyLocal := hasLocalModel && !hasCloudModel
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b selectItem) int {
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
// Checked/pre-selected always first
if ac != bc {
if ac {
return -1
}
return 1
}
if !ac && !bc && aNew != bNew {
// Recommended above non-recommended
if aRec != bRec {
if aRec {
return -1
}
return 1
}
// Both recommended
if aRec && bRec {
if aCloud != bCloud {
if onlyLocal {
// Local before cloud when only local installed
if aCloud {
return 1
}
return -1
}
// Cloud before local in mixed case
if aCloud {
return -1
}
return 1
}
return recRank[a.Name] - recRank[b.Name]
}
// Both non-recommended: installed before not-installed
if aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
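// Worked example of the tiering above, mirroring
// TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom: with only llama3.2 and
// qwen2.5 installed (onlyLocal == true), local recs lead and non-recommended
// installed models follow:
//   glm-4.7-flash, qwen3:8b, kimi-k2.5:cloud, glm-4.7:cloud, llama3.2, qwen2.5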
@@ -478,8 +1091,45 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
return items, preChecked, existingModels, cloudModels
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
// isCloudModel checks if a model is a cloud model using the Show API.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
}
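// Illustrative usage: a minimal sketch assuming a reachable server. The helper
// name routeModel is hypothetical.
func routeModel(ctx context.Context, name string) string {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return "local"
	}
	if isCloudModel(ctx, client, name) {
		return "cloud" // Show reported a non-empty RemoteModel
	}
	return "local" // unreachable server, unknown model, or local model
}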
// GetModelItems returns a list of model items including recommendations for the TUI.
// It includes all locally available models plus recommended models that aren't installed.
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil
}
models, err := client.List(ctx)
if err != nil {
return nil, nil
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
lastModel := LastModel()
var preChecked []string
if lastModel != "" {
preChecked = []string{lastModel}
}
items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
return items, existingModels
}
func pullModel(ctx context.Context, client *api.Client, model string) error {

View File

@@ -1,12 +1,18 @@
package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
@@ -88,8 +94,8 @@ func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
cmd := LaunchCmd(mockCheck)
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
@@ -122,6 +128,75 @@ func TestLaunchCmd(t *testing.T) {
})
}
func TestLaunchCmd_TUICallback(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
t.Run("no args calls TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{})
_ = cmd.Execute()
if !tuiCalled {
t.Error("TUI callback should be called when no args provided")
}
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"claude"})
// Will error because claude isn't configured, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when integration arg provided")
}
})
t.Run("--model flag bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model"})
// Will error because no integration specified, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when --model flag provided")
}
})
t.Run("--config flag bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--config"})
// Will error because no integration specified, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when --config flag provided")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model", nil)
if err == nil {
@@ -162,7 +237,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := LaunchCmd(nil)
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
@@ -297,27 +372,18 @@ func TestParseArgs(t *testing.T) {
}
func TestIsCloudModel(t *testing.T) {
	// isCloudModel now only uses Show API, so nil client always returns false
	t.Run("nil client returns false", func(t *testing.T) {
		models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
		for _, model := range models {
			if isCloudModel(context.Background(), nil, model) {
				t.Errorf("isCloudModel(%q) with nil client should return false", model)
			}
		}
	})
}
func names(items []selectItem) []string {
func names(items []ModelItem) []string {
var out []string
for _, item := range items {
out = append(out, item.Name)
@@ -328,7 +394,7 @@ func names(items []selectItem) []string {
func TestBuildModelList_NoExistingModels(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "")
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b"}
if diff := cmp.Diff(want, names(items)); diff != "" {
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
}
@@ -349,9 +415,10 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
want := []string{"glm-4.7-flash", "qwen3:8b", "kimi-k2.5:cloud", "glm-4.7:cloud", "llama3.2", "qwen2.5"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
}
}
@@ -364,9 +431,10 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
// All recs pinned at top (cloud before local in mixed case), then non-recs
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
}
}
@@ -417,9 +485,10 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
// All recs: cloud first in mixed case, then local, in rec order within each
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
}
}
@@ -434,9 +503,10 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
// kimi-k2.5:cloud is installed so it sorts normally;
// the rest of the recommendations are not installed so they go to the bottom
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
// All recs pinned at top (cloud first in mixed case), then non-recs
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
}
for _, item := range items {
@@ -509,3 +579,588 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
t.Error("llama3.2 should not be in cloudModels")
}
}
func TestBuildModelList_RecommendedFieldSet(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "llama3.2:latest", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
for _, item := range items {
switch item.Name {
case "glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud":
if !item.Recommended {
t.Errorf("%q should have Recommended=true", item.Name)
}
case "llama3.2":
if item.Recommended {
t.Errorf("%q should have Recommended=false", item.Name)
}
}
}
}
func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// Cloud recs should sort before local recs in mixed case
cloudIdx := slices.Index(got, "glm-4.7:cloud")
localIdx := slices.Index(got, "glm-4.7-flash")
if cloudIdx > localIdx {
t.Errorf("cloud recs should be before local recs in mixed case, got %v", got)
}
}
func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// Local recs should sort before cloud recs in only-local case
localIdx := slices.Index(got, "glm-4.7-flash")
cloudIdx := slices.Index(got, "glm-4.7:cloud")
if localIdx > cloudIdx {
t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
}
}
func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "custom-model", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// All recommended models should appear before non-recommended installed models
lastRecIdx := -1
firstNonRecIdx := len(got)
for i, name := range got {
isRec := name == "glm-4.7-flash" || name == "qwen3:8b" || name == "glm-4.7:cloud" || name == "kimi-k2.5:cloud"
if isRec && i > lastRecIdx {
lastRecIdx = i
}
if !isRec && i < firstNonRecIdx {
firstNonRecIdx = i
}
}
if lastRecIdx > firstNonRecIdx {
t.Errorf("all recs should be above non-recs, got %v", got)
}
}
func TestBuildModelList_CheckedBeforeRecs(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
got := names(items)
if got[0] != "llama3.2" {
t.Errorf("checked model should be first even before recs, got %v", got)
}
}
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify loadIntegration returns the saved models
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatal(err)
}
if len(saved.Models) == 0 {
t.Fatal("expected saved models")
}
if saved.Models[0] != "llama3.2" {
t.Errorf("expected llama3.2, got %s", saved.Models[0])
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
t.Error("Claude should implement AliasConfigurer")
}
})
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
codex := &Codex{}
if _, ok := interface{}(codex).(AliasConfigurer); ok {
t.Error("Codex should not implement AliasConfigurer")
}
})
}
func TestShowOrPull_ModelExists(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"test-model"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ShowOrPull(context.Background(), client, "test-model")
if err != nil {
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
}
}
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// confirmPrompt will fail in test (no terminal), so ShowOrPull should return an error
err := ShowOrPull(context.Background(), client, "missing-model")
if err == nil {
t.Error("showOrPull should return error when model not found and no terminal available")
}
}
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
var receivedModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
receivedModel = req.Model
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
_ = ShowOrPull(context.Background(), client, "qwen3:8b")
if receivedModel != "qwen3:8b" {
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
}
}
func TestShowOrPull_ModelNotFound_ConfirmYes_Pulls(t *testing.T) {
// Set up hook so confirmPrompt doesn't need a terminal
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
if !strings.Contains(prompt, "missing-model") {
t.Errorf("expected prompt to contain model name, got %q", prompt)
}
return true, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ShowOrPull(context.Background(), client, "missing-model")
if err != nil {
t.Errorf("ShowOrPull should succeed after pull, got: %v", err)
}
if !pullCalled {
t.Error("expected pull to be called when user confirms download")
}
}
func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
return false, ErrCancelled
}
defer func() { DefaultConfirmPrompt = oldHook }()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
case "/api/pull":
t.Error("pull should not be called when user declines")
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ShowOrPull(context.Background(), client, "missing-model")
if err == nil {
t.Error("ShowOrPull should return error when user declines")
}
}
func TestConfirmPrompt_DelegatesToHook(t *testing.T) {
oldHook := DefaultConfirmPrompt
var hookCalled bool
DefaultConfirmPrompt = func(prompt string) (bool, error) {
hookCalled = true
if prompt != "test prompt?" {
t.Errorf("expected prompt %q, got %q", "test prompt?", prompt)
}
return true, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
ok, err := confirmPrompt("test prompt?")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !ok {
t.Error("expected true from hook")
}
if !hookCalled {
t.Error("expected DefaultConfirmPrompt hook to be called")
}
}
func TestEnsureAuth_NoCloudModels(t *testing.T) {
// ensureAuth should be a no-op when no cloud models are selected
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("no API calls expected when no cloud models selected")
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
if err != nil {
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
}
}
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
// ensureAuth should only care about models in cloudModels map
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"name":"testuser"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"cloud-model:cloud", "local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
}
if !whoamiCalled {
t.Error("expected whoami to be called for cloud model")
}
}
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// cloudModels has entries but none are in selected
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if whoamiCalled {
t.Error("whoami should not be called when no cloud models are selected")
}
}
func TestHyperlink(t *testing.T) {
tests := []struct {
name string
url string
text string
wantURL string
wantText string
}{
{
name: "basic link",
url: "https://example.com",
text: "click here",
wantURL: "https://example.com",
wantText: "click here",
},
{
name: "url with path",
url: "https://example.com/docs/install",
text: "install docs",
wantURL: "https://example.com/docs/install",
wantText: "install docs",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := hyperlink(tt.url, tt.text)
// Should contain OSC 8 escape sequences
if !strings.Contains(got, "\033]8;;") {
t.Error("should contain OSC 8 open sequence")
}
if !strings.Contains(got, tt.wantURL) {
t.Errorf("should contain URL %q", tt.wantURL)
}
if !strings.Contains(got, tt.wantText) {
t.Errorf("should contain text %q", tt.wantText)
}
// Should have closing OSC 8 sequence
wantSuffix := "\033]8;;\033\\"
if !strings.HasSuffix(got, wantSuffix) {
t.Error("should end with OSC 8 close sequence")
}
})
}
}
func TestIntegrationInstallHint(t *testing.T) {
tests := []struct {
name string
input string
wantEmpty bool
wantURL string
}{
{
name: "claude has hint",
input: "claude",
wantURL: "https://code.claude.com/docs/en/quickstart",
},
{
name: "codex has hint",
input: "codex",
wantURL: "https://developers.openai.com/codex/cli/",
},
{
name: "openclaw has hint",
input: "openclaw",
wantURL: "https://docs.openclaw.ai",
},
{
name: "unknown has no hint",
input: "unknown",
wantEmpty: true,
},
{
name: "empty name has no hint",
input: "",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IntegrationInstallHint(tt.input)
if tt.wantEmpty {
if got != "" {
t.Errorf("expected empty hint, got %q", got)
}
return
}
if !strings.Contains(got, "Install from") {
t.Errorf("hint should start with 'Install from', got %q", got)
}
if !strings.Contains(got, tt.wantURL) {
t.Errorf("hint should contain URL %q, got %q", tt.wantURL, got)
}
// Should be a clickable hyperlink
if !strings.Contains(got, "\033]8;;") {
t.Error("hint URL should be wrapped in OSC 8 hyperlink")
}
})
}
}
func TestListIntegrationInfos(t *testing.T) {
infos := ListIntegrationInfos()
t.Run("excludes aliases", func(t *testing.T) {
for _, info := range infos {
if integrationAliases[info.Name] {
t.Errorf("alias %q should not appear in ListIntegrationInfos", info.Name)
}
}
})
t.Run("sorted by name", func(t *testing.T) {
for i := 1; i < len(infos); i++ {
if infos[i-1].Name >= infos[i].Name {
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
}
}
})
t.Run("all fields populated", func(t *testing.T) {
for _, info := range infos {
if info.Name == "" {
t.Error("Name should not be empty")
}
if info.DisplayName == "" {
t.Errorf("DisplayName for %q should not be empty", info.Name)
}
}
})
t.Run("includes known integrations", func(t *testing.T) {
known := map[string]bool{"claude": false, "codex": false, "opencode": false}
for _, info := range infos {
if _, ok := known[info.Name]; ok {
known[info.Name] = true
}
}
for name, found := range known {
if !found {
t.Errorf("expected %q in ListIntegrationInfos", name)
}
}
})
}
func TestBuildModelList_Descriptions(t *testing.T) {
t.Run("installed recommended has base description", func(t *testing.T) {
existing := []modelInfo{
{Name: "qwen3:8b", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
for _, item := range items {
if item.Name == "qwen3:8b" {
if strings.HasSuffix(item.Description, "install?") {
t.Errorf("installed model should not have 'install?' suffix, got %q", item.Description)
}
if item.Description == "" {
t.Error("installed recommended model should have a description")
}
return
}
}
t.Error("qwen3:8b not found in items")
})
t.Run("not-installed local rec has VRAM in description", func(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "")
for _, item := range items {
if item.Name == "qwen3:8b" {
if !strings.Contains(item.Description, "~11GB") {
t.Errorf("not-installed qwen3:8b should show VRAM hint, got %q", item.Description)
}
return
}
}
t.Error("qwen3:8b not found in items")
})
t.Run("installed local rec omits VRAM", func(t *testing.T) {
existing := []modelInfo{
{Name: "qwen3:8b", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
for _, item := range items {
if item.Name == "qwen3:8b" {
if strings.Contains(item.Description, "~11GB") {
t.Errorf("installed qwen3:8b should not show VRAM hint, got %q", item.Description)
}
return
}
}
t.Error("qwen3:8b not found in items")
})
}
func TestLaunchIntegration_UnknownIntegration(t *testing.T) {
err := LaunchIntegration("nonexistent-integration")
if err == nil {
t.Fatal("expected error for unknown integration")
}
if !strings.Contains(err.Error(), "unknown integration") {
t.Errorf("error should mention 'unknown integration', got: %v", err)
}
}
func TestLaunchIntegration_NotConfigured(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Claude is a known integration but not configured in temp dir
err := LaunchIntegration("claude")
if err == nil {
t.Fatal("expected error when integration is not configured")
}
if !strings.Contains(err.Error(), "not configured") {
t.Errorf("error should mention 'not configured', got: %v", err)
}
}

View File

@@ -17,8 +17,6 @@ type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {

View File

@@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/json"
"fmt"
"maps"
@@ -10,12 +11,53 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
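// Illustrative usage: both name forms resolve to the same table entry; the
// function name exampleLimitLookup is hypothetical.
func exampleLimitLookup() {
	if l, ok := lookupCloudModelLimit("glm-4.7:cloud"); ok {
		fmt.Println(l.Context, l.Output) // 202752 131072, per the table above
	}
	if _, ok := lookupCloudModelLimit("not-a-cloud-model"); !ok {
		fmt.Println("unknown models report ok == false")
	}
}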
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
@@ -113,6 +155,8 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -122,12 +166,29 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
entry := map[string]any{
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models


@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -495,6 +496,166 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
if entry["limit"] != nil {
t.Errorf("local model should not have limit set, got %v", entry["limit"])
}
}
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
// Set up a model with a user-configured limit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"llama3.2": {
"name": "llama3.2",
"_launch": true,
"limit": {"context": 8192, "output": 4096}
}
}
}
}
}`), 0o644)
// Re-edit should preserve the user's limit (not delete it)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("user-configured limit was removed")
}
if limit["context"] != float64(8192) {
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
}
if limit["output"] != float64(4096) {
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
}
}
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
// Verify that when a cloud model entry has limits set (as Edit would do),
// the structure matches what opencode expects and re-edit preserves them.
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
expected := cloudModelLimits["glm-4.7"]
// Simulate a cloud model that already has the limit set by a previous Edit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
"provider": {
"ollama": {
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud",
"_launch": true,
"limit": {"context": %d, "output": %d}
}
}
}
}
}`, expected.Context, expected.Output)), 0o644)
// Re-edit should preserve the cloud model limit
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was removed on re-edit")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()

cmd/config/pi.go Normal file (237 lines added)

@@ -0,0 +1,237 @@
package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
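// Note: Run deliberately re-runs Edit on every launch so models.json and
// settings.json stay in sync with the saved integration config before the
// terminal is handed to pi.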
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
client := api.NewClient(envconfig.Host(), http.DefaultClient)
ctx := context.Background()
for _, model := range models {
if selectedSet[model] {
newModels = append(newModels, createConfig(ctx, client, model))
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0]
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
}
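// For reference, Edit leaves settings.json looking roughly like this, with
// unrelated user settings preserved (sketch; key order and formatting vary):
//
//	{
//	  "defaultProvider": "ollama",
//	  "defaultModel": "llama3.2"
//	}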
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}
// createConfig builds Pi model config with capability detection
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
cfg := map[string]any{
"id": modelID,
"_launch": true,
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
if err != nil {
return cfg
}
// Set input types based on vision capability
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
cfg["input"] = []string{"text", "image"}
} else {
cfg["input"] = []string{"text"}
}
// Set reasoning based on thinking capability
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
cfg["reasoning"] = true
}
// Extract context window from ModelInfo
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
cfg["contextWindow"] = int(ctxLen)
}
break
}
}
return cfg
}
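// A fully populated result from createConfig looks roughly like the following
// (illustrative: assumes a vision+thinking model reporting a 32768 context):
//
//	map[string]any{
//	  "id":            "qwen3-vision",
//	  "_launch":       true,
//	  "input":         []string{"text", "image"},
//	  "reasoning":     true,
//	  "contextWindow": 32768,
//	}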

cmd/config/pi_test.go Normal file (830 lines added)

@@ -0,0 +1,830 @@
package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestPiIntegration(t *testing.T) {
pi := &Pi{}
t.Run("String", func(t *testing.T) {
if got := pi.String(); got != "Pi" {
t.Errorf("String() = %q, want %q", got, "Pi")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = pi
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = pi
})
}
func TestPiPaths(t *testing.T) {
pi := &Pi{}
t.Run("returns empty when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
paths := pi.Paths()
if len(paths) != 0 {
t.Errorf("Paths() = %v, want empty", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
t.Fatal(err)
}
paths := pi.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}
func TestPiEdit(t *testing.T) {
// Mock Ollama server for createConfig calls during Edit
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
pi := &Pi{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
configPath := filepath.Join(configDir, "models.json")
cleanup := func() {
os.RemoveAll(configDir)
}
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
return cfg
}
t.Run("returns nil for empty models", func(t *testing.T) {
if err := pi.Edit([]string{}); err != nil {
t.Errorf("Edit([]) error = %v, want nil", err)
}
})
t.Run("creates config with models", func(t *testing.T) {
cleanup()
models := []string{"llama3.2", "qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers, ok := cfg["providers"].(map[string]any)
if !ok {
t.Error("Config missing providers")
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
t.Error("Providers missing ollama")
}
modelsArray, ok := ollama["models"].([]any)
if !ok || len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %v", modelsArray)
}
if ollama["baseUrl"] == nil {
t.Error("Missing baseUrl")
}
if ollama["api"] != "openai-completions" {
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
}
if ollama["apiKey"] != "ollama" {
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
}
})
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://custom:8080/v1",
"api": "custom-api",
"apiKey": "custom-key",
"models": [
{"id": "old-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"new-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
if ollama["baseUrl"] != "http://custom:8080/v1" {
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
}
if ollama["api"] != "custom-api" {
t.Errorf("Custom api not preserved, got %v", ollama["api"])
}
if ollama["apiKey"] != "custom-key" {
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
}
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
} else {
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["id"] != "new-model" {
t.Errorf("Expected new-model, got %v", modelEntry["id"])
}
// Verify _launch marker is present
if modelEntry["_launch"] != true {
t.Errorf("Expected _launch marker to be true")
}
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Old models must have _launch marker to be managed by us
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "old-model-1", "_launch": true},
{"id": "old-model-2", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"new-model-1", "new-model-2"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
t.Errorf("Expected new models, got %v", modelIDs)
}
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
t.Errorf("Old models should have been removed, got %v", modelIDs)
}
})
t.Run("handles partial overlap in model list", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Models must have _launch marker to be managed
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "keep-model", "_launch": true},
{"id": "remove-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"keep-model", "add-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
}
if modelIDs["remove-model"] {
t.Errorf("remove-model should have been removed")
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model, got %d", len(modelsArray))
}
})
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// User has manually configured models in ollama provider (no _launch marker)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "user-model-1"},
{"id": "user-model-2", "customField": "preserved"},
{"id": "ollama-managed", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
// Add a new ollama-managed model
newModels := []string{"new-ollama-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
// Should have: new-ollama-model (managed) + 2 user models (preserved)
if len(modelsArray) != 3 {
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
}
modelIDs := make(map[string]map[string]any)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = modelObj
}
// Verify new model has _launch marker
if m, ok := modelIDs["new-ollama-model"]; !ok {
t.Errorf("new-ollama-model should be present")
} else if m["_launch"] != true {
t.Errorf("new-ollama-model should have _launch marker")
}
// Verify user models are preserved
if _, ok := modelIDs["user-model-1"]; !ok {
t.Errorf("user-model-1 should be preserved")
}
if _, ok := modelIDs["user-model-2"]; !ok {
t.Errorf("user-model-2 should be preserved")
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
t.Errorf("user-model-2 customField should be preserved")
}
// Verify old ollama-managed model is removed (not in new list)
if _, ok := modelIDs["ollama-managed"]; ok {
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
}
})
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create existing settings with other fields
settingsPath := filepath.Join(configDir, "settings.json")
existingSettings := `{
"theme": "dark",
"customSetting": "value",
"defaultProvider": "anthropic",
"defaultModel": "claude-3"
}`
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"llama3.2"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
// Verify defaultProvider is set to ollama
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
// Verify defaultModel is set to first model
if settings["defaultModel"] != "llama3.2" {
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
}
// Verify other fields are preserved
if settings["theme"] != "dark" {
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
}
if settings["customSetting"] != "value" {
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
}
})
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
models := []string{"qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
settingsPath := filepath.Join(configDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("settings.json should be created: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "qwen3:8b" {
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
}
})
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create corrupt settings
settingsPath := filepath.Join(configDir, "settings.json")
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "test-model" {
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
}
})
}
func TestPiModels(t *testing.T) {
pi := &Pi{}
t.Run("returns nil when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns models from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "llama3.2"},
{"id": "qwen3:8b"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if len(models) != 2 {
t.Errorf("Models() returned %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
}
})
t.Run("returns sorted models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "z-model"},
{"id": "a-model"},
{"id": "m-model"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
}
})
t.Run("returns nil when models array is missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil when models array is missing", models)
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil for corrupt config", models)
}
})
}
func TestIsPiOllamaModel(t *testing.T) {
tests := []struct {
name string
cfg map[string]any
want bool
}{
{"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
{"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
{"without _launch", map[string]any{"id": "m"}, false},
{"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
{"empty map", map[string]any{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isPiOllamaModel(tt.cfg); got != tt.want {
t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
}
})
}
}
func TestCreateConfig(t *testing.T) {
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "llava:7b")
if cfg["id"] != "llava:7b" {
t.Errorf("id = %v, want llava:7b", cfg["id"])
}
if cfg["_launch"] != true {
t.Error("expected _launch = true")
}
input, ok := cfg["input"].([]string)
if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
})
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "llama3.2")
input, ok := cfg["input"].([]string)
if !ok || len(input) != 1 || input[0] != "text" {
t.Errorf("input = %v, want [text]", cfg["input"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set for non-thinking model")
}
})
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "qwq")
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for thinking model")
}
})
t.Run("extracts context window from model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "llama3.2")
if cfg["contextWindow"] != 131072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("handles all capabilities together", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "qwen3-vision")
input := cfg["input"].([]string)
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
t.Errorf("input = %v, want [text image]", input)
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true")
}
if cfg["contextWindow"] != 32768 {
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
}
})
t.Run("returns minimal config when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "missing-model")
if cfg["id"] != "missing-model" {
t.Errorf("id = %v, want missing-model", cfg["id"])
}
if cfg["_launch"] != true {
t.Error("expected _launch = true")
}
// Should not have capability fields
if _, ok := cfg["input"]; ok {
t.Error("input should not be set when show fails")
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set when show fails")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set when show fails")
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "test-model")
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set for zero value")
}
})
}
// Ensure Capability constants used in createConfig match expected values
func TestPiCapabilityConstants(t *testing.T) {
if model.CapabilityVision != "vision" {
t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
}
if model.CapabilityThinking != "thinking" {
t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
}
}


@@ -3,474 +3,34 @@ package config
import (
"errors"
"fmt"
"io"
"os"
"strings"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiHideCursor = "\033[?25l"
ansiShowCursor = "\033[?25h"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiClearDown = "\033[J"
ansiGreen = "\033[32m"
)
const maxDisplayedItems = 10
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")
// errCancelled is kept as an alias for backward compatibility within the package.
var errCancelled = ErrCancelled
type selectItem struct {
Name string
Description string
}
type inputEvent int
const (
eventNone inputEvent = iota
eventEnter
eventEscape
eventUp
eventDown
eventTab
eventBackspace
eventChar
)
type selectState struct {
items []selectItem
filter string
selected int
scrollOffset int
}
func newSelectState(items []selectItem) *selectState {
return &selectState{items: items}
}
func (s *selectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if len(filtered) > 0 && s.selected < len(filtered) {
return true, filtered[s.selected].Name, nil
}
case eventEscape:
return true, "", errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.selected = 0
s.scrollOffset = 0
}
case eventUp:
if s.selected > 0 {
s.selected--
if s.selected < s.scrollOffset {
s.scrollOffset = s.selected
}
}
case eventDown:
if s.selected < len(filtered)-1 {
s.selected++
if s.selected >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.selected - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.selected = 0
s.scrollOffset = 0
}
return false, "", nil
}
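// Behavioral summary of the switch above (a sketch, not extra logic):
//
//	eventChar 'q','w' -> filter == "qw"; selection and scroll reset to 0
//	eventDown/eventUp -> move within the filtered view, scrolling as needed
//	eventEnter        -> done == true with the highlighted filtered item
//	eventEscape       -> done == true with errCancelled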
type multiSelectState struct {
items []selectItem
itemIndex map[string]int
filter string
highlighted int
scrollOffset int
checked map[int]bool
checkOrder []int
focusOnButton bool
}
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
s := &multiSelectState{
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
s.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := s.itemIndex[name]; ok {
s.checked[idx] = true
s.checkOrder = append(s.checkOrder, idx)
}
}
return s
}
func (s *multiSelectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *multiSelectState) toggleItem() {
filtered := s.filtered()
if len(filtered) == 0 || s.highlighted >= len(filtered) {
return
}
item := filtered[s.highlighted]
origIdx := s.itemIndex[item.Name]
if s.checked[origIdx] {
delete(s.checked, origIdx)
for i, idx := range s.checkOrder {
if idx == origIdx {
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
break
}
}
} else {
s.checked[origIdx] = true
s.checkOrder = append(s.checkOrder, origIdx)
}
}
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if s.focusOnButton && len(s.checkOrder) > 0 {
var res []string
for _, idx := range s.checkOrder {
res = append(res, s.items[idx].Name)
}
return true, res, nil
} else if !s.focusOnButton {
s.toggleItem()
}
case eventTab:
if len(s.checkOrder) > 0 {
s.focusOnButton = !s.focusOnButton
}
case eventEscape:
return true, nil, errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
case eventUp:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted > 0 {
s.highlighted--
if s.highlighted < s.scrollOffset {
s.scrollOffset = s.highlighted
}
}
case eventDown:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted < len(filtered)-1 {
s.highlighted++
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
return false, nil, nil
}
func (s *multiSelectState) selectedCount() int {
return len(s.checkOrder)
}
// Terminal I/O handling
type terminalState struct {
fd int
oldState *term.State
}
func enterRawMode() (*terminalState, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return nil, err
}
fmt.Fprint(os.Stderr, ansiHideCursor)
return &terminalState{fd: fd, oldState: oldState}, nil
}
func (t *terminalState) restore() {
fmt.Fprint(os.Stderr, ansiShowCursor)
term.Restore(t.fd, t.oldState)
}
func clearLines(n int) {
if n > 0 {
fmt.Fprintf(os.Stderr, "\033[%dA", n)
fmt.Fprint(os.Stderr, ansiClearDown)
}
}
func parseInput(r io.Reader) (inputEvent, byte, error) {
buf := make([]byte, 3)
n, err := r.Read(buf)
if err != nil {
return 0, 0, err
}
switch {
case n == 1 && buf[0] == 13:
return eventEnter, 0, nil
case n == 1 && (buf[0] == 3 || buf[0] == 27):
return eventEscape, 0, nil
case n == 1 && buf[0] == 9:
return eventTab, 0, nil
case n == 1 && buf[0] == 127:
return eventBackspace, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
return eventUp, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
return eventDown, 0, nil
case n == 1 && buf[0] >= 32 && buf[0] < 127:
return eventChar, buf[0], nil
}
return eventNone, 0, nil
}
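// Byte-to-event mapping handled above (arrow keys arrive as 3-byte CSI
// sequences; everything else is a single byte):
//
//	0x0D (CR)         -> eventEnter
//	0x03, 0x1B        -> eventEscape (Ctrl+C or a bare ESC)
//	0x09 (TAB)        -> eventTab
//	0x7F (DEL)        -> eventBackspace
//	ESC '[' 'A' / 'B' -> eventUp / eventDown
//	0x20-0x7E         -> eventChar (printable ASCII)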
// Rendering
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
prefix := " "
if idx == s.selected {
prefix = " " + ansiBold + "> "
}
if item.Description != "" {
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
} else {
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
return lineCount
}
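// Roughly what renderSelect draws (ANSI escapes elided; model names and
// descriptions are illustrative; the trailing "... and N more" line only
// appears when the filtered list exceeds maxDisplayedItems):
//
//	Select a model: Type to filter...
//	  > llama3.2 - 2.0GB
//	    qwen3:8b - ~11GB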
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := s.itemIndex[item.Name]
checkbox := "[ ]"
if s.checked[origIdx] {
checkbox = "[x]"
}
prefix := " "
suffix := ""
if idx == s.highlighted && !s.focusOnButton {
prefix = "> "
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
suffix = " " + ansiGray + "(default)" + ansiReset
}
desc := ""
if item.Description != "" {
desc = " " + ansiGray + "- " + item.Description + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
fmt.Fprintf(w, "\r\n")
lineCount++
count := s.selectedCount()
switch {
case count == 0:
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
case s.focusOnButton:
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
default:
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
}
lineCount++
return lineCount
}
// selectPrompt prompts the user to select a single item from a list.
func selectPrompt(prompt string, items []selectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return "", err
}
defer ts.restore()
state := newSelectState(items)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return "", err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return "", err
}
return result, nil
}
render()
}
}
// multiSelectPrompt prompts the user to select multiple items from a list.
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return nil, err
}
defer ts.restore()
state := newMultiSelectState(items, preChecked)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return nil, err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return nil, err
}
return result, nil
}
render()
}
}
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)
func confirmPrompt(prompt string) (bool, error) {
if DefaultConfirmPrompt != nil {
return DefaultConfirmPrompt(prompt)
}
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
@@ -496,17 +56,3 @@ func confirmPrompt(prompt string) (bool, error) {
}
}
}
func filterItems(items []selectItem, filter string) []selectItem {
if filter == "" {
return items
}
var result []selectItem
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}


@@ -1,651 +1,9 @@
package config
import (
"bytes"
"strings"
"testing"
)
func TestFilterItems(t *testing.T) {
items := []selectItem{
{Name: "llama3.2:latest"},
{Name: "qwen2.5:7b"},
{Name: "deepseek-v3:cloud"},
{Name: "GPT-OSS:20b"},
}
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
result := filterItems(items, "")
if len(result) != len(items) {
t.Errorf("expected %d items, got %d", len(items), len(result))
}
})
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
result := filterItems(items, "LLAMA")
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
t.Errorf("expected llama3.2:latest, got %v", result)
}
})
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
result := filterItems(items, "gpt")
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
t.Errorf("expected GPT-OSS:20b, got %v", result)
}
})
t.Run("PartialMatch", func(t *testing.T) {
result := filterItems(items, "deep")
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
t.Errorf("expected deepseek-v3:cloud, got %v", result)
}
})
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
result := filterItems(items, "nonexistent")
if len(result) != 0 {
t.Errorf("expected 0 items, got %d", len(result))
}
})
}
func TestSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState", func(t *testing.T) {
s := newSelectState(items)
if s.selected != 0 {
t.Errorf("expected selected=0, got %d", s.selected)
}
if s.filter != "" {
t.Errorf("expected empty filter, got %q", s.filter)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item1" || err != nil {
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
s := newSelectState(items)
s.filter = "item3"
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item3" || err != nil {
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.filter = "nonexistent"
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != "" || err != errCancelled {
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Down_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventDown, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventDown, 0)
if s.selected != 2 {
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
}
})
t.Run("Up_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventUp, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventChar, 'i')
s.handleInput(eventChar, 't')
s.handleInput(eventChar, 'e')
s.handleInput(eventChar, 'm')
s.handleInput(eventChar, '2')
if s.filter != "item2" {
t.Errorf("expected filter='item2', got %q", s.filter)
}
filtered := s.filtered()
if len(filtered) != 1 || filtered[0].Name != "item2" {
t.Errorf("expected [item2], got %v", filtered)
}
})
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventChar, 'x')
if s.selected != 0 {
t.Errorf("expected selected=0 after typing, got %d", s.selected)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventBackspace, 0)
if s.filter != "" {
t.Errorf("expected filter='', got %q", s.filter)
}
})
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.selected = 2
s.handleInput(eventBackspace, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
}
})
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
// maxDisplayedItems is 10, so with 15 items we need to scroll
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
// move down 12 times (past the 10-item viewport)
for range 12 {
s.handleInput(eventDown, 0)
}
if s.selected != 12 {
t.Errorf("expected selected=12, got %d", s.selected)
}
if s.scrollOffset != 3 {
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
}
})
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
s.selected = 5
s.scrollOffset = 5
s.handleInput(eventUp, 0)
if s.selected != 4 {
t.Errorf("expected selected=4, got %d", s.selected)
}
if s.scrollOffset != 4 {
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
}
})
}
func TestMultiSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected, got %d", s.selectedCount())
}
if s.focusOnButton {
t.Error("expected focusOnButton=false initially")
}
})
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item3"})
if s.selectedCount() != 2 {
t.Errorf("expected 2 selected, got %d", s.selectedCount())
}
if !s.checked[1] || !s.checked[2] {
t.Error("expected item2 and item3 to be checked")
}
})
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
// order matters: first checked = default model
s := newMultiSelectState(items, []string{"item3", "item1"})
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
}
})
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
if s.selectedCount() != 1 {
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
}
})
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.toggleItem()
if !s.checked[0] {
t.Error("expected item1 to be checked after toggle")
}
})
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.toggleItem()
if s.checked[0] {
t.Error("expected item1 to be unchecked after toggle")
}
})
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
s.highlighted = 1 // toggle item2
s.toggleItem()
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
// should be [0, 2] (item1, item3) with item2 removed
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
}
})
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventEnter, 0)
if !s.checked[0] {
t.Error("expected item1 to be checked after enter")
}
})
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if !done || err != nil {
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
}
// result should preserve selection order
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
t.Errorf("expected [item2, item1], got %v", result)
}
})
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if done || result != nil || err != nil {
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after tab")
}
})
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("tab should not focus button when nothing selected")
}
})
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after first tab")
}
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("expected focus back on list after second tab")
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != nil || err != errCancelled {
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
t.Error("expected item2 (idx 1) to be default (first checked)")
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected item1 (idx 0) to NOT be default")
}
})
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected isDefault=false when nothing checked")
}
})
t.Run("Down_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventDown, 0)
if s.highlighted != 1 {
t.Errorf("expected highlighted=1, got %d", s.highlighted)
}
})
t.Run("Up_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.highlighted = 1
s.handleInput(eventUp, 0)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
})
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.focusOnButton = true
s.handleInput(eventDown, 0)
if s.focusOnButton {
t.Error("expected focus to return to list on arrow key")
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventChar, 'x')
if s.filter != "x" {
t.Errorf("expected filter='x', got %q", s.filter)
}
})
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newMultiSelectState(manyItems, nil)
s.highlighted = 10
s.scrollOffset = 5
s.handleInput(eventChar, 'x')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.filter = "x"
s.focusOnButton = true
s.handleInput(eventBackspace, 0)
if s.focusOnButton {
t.Error("expected focusOnButton=false after backspace")
}
})
}
func TestParseInput(t *testing.T) {
t.Run("Enter", func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{13}))
if err != nil || event != eventEnter || char != 0 {
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
}
})
t.Run("Escape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape, got %v", event)
}
})
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{3}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
}
})
t.Run("Tab", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{9}))
if err != nil || event != eventTab {
t.Errorf("expected eventTab, got %v", event)
}
})
t.Run("Backspace", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{127}))
if err != nil || event != eventBackspace {
t.Errorf("expected eventBackspace, got %v", event)
}
})
t.Run("UpArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
if err != nil || event != eventUp {
t.Errorf("expected eventUp, got %v", event)
}
})
t.Run("DownArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
if err != nil || event != eventDown {
t.Errorf("expected eventDown, got %v", event)
}
})
t.Run("PrintableChars", func(t *testing.T) {
tests := []struct {
name string
char byte
}{
{"lowercase", 'a'},
{"uppercase", 'Z'},
{"digit", '5'},
{"space", ' '},
{"tilde", '~'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
if err != nil || event != eventChar || char != tt.char {
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
}
})
}
})
}
func TestRenderSelect(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "first item"},
{Name: "item2"},
}
t.Run("ShowsPromptAndItems", func(t *testing.T) {
s := newSelectState(items)
var buf bytes.Buffer
lineCount := renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "Select:") {
t.Error("expected prompt in output")
}
if !strings.Contains(output, "item1") {
t.Error("expected item1 in output")
}
if !strings.Contains(output, "first item") {
t.Error("expected description in output")
}
if !strings.Contains(output, "item2") {
t.Error("expected item2 in output")
}
if lineCount != 3 { // 1 prompt + 2 items
t.Errorf("expected 3 lines, got %d", lineCount)
}
})
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
s := newSelectState(items)
s.filter = "xyz"
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
}
})
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
// 15 items - 10 displayed = 5 more
if !strings.Contains(buf.String(), "5 more") {
t.Error("expected '5 more' indicator")
}
})
}
func TestRenderMultiSelect(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
}
t.Run("ShowsCheckboxes", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "[x]") {
t.Error("expected checked checkbox [x]")
}
if !strings.Contains(output, "[ ]") {
t.Error("expected unchecked checkbox [ ]")
}
})
t.Run("ShowsDefaultMarker", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "(default)") {
t.Error("expected (default) marker for first checked item")
}
})
t.Run("ShowsSelectedCount", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "2 selected") {
t.Error("expected '2 selected' in output")
}
})
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
s := newMultiSelectState(items, nil)
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "Select at least one") {
t.Error("expected 'Select at least one' helper text")
}
})
}
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
@@ -659,255 +17,3 @@ func TestErrCancelled(t *testing.T) {
}
})
}
// Edge case tests for selector.go
// TestSelectState_SingleItem verifies that a list with only one item works without crashing.
func TestSelectState_SingleItem(t *testing.T) {
items := []selectItem{{Name: "only-one"}}
s := newSelectState(items)
// Down should do nothing (already at bottom)
s.handleInput(eventDown, 0)
if s.selected != 0 {
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
}
// Up should do nothing (already at top)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
}
// Enter should select the only item
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "only-one" || err != nil {
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
}
}
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
// List with exactly maxDisplayedItems items should not scroll.
func TestSelectState_ExactlyMaxItems(t *testing.T) {
items := make([]selectItem, maxDisplayedItems)
for i := range items {
items[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(items)
// Move to last item
for range maxDisplayedItems - 1 {
s.handleInput(eventDown, 0)
}
if s.selected != maxDisplayedItems-1 {
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
// Should not scroll when exactly at max
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
}
// One more down should do nothing
s.handleInput(eventDown, 0)
if s.selected != maxDisplayedItems-1 {
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
}
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
// User typing "model.v1" shouldn't match "modelsv1".
func TestFilterItems_RegexSpecialChars(t *testing.T) {
items := []selectItem{
{Name: "model.v1"},
{Name: "modelsv1"},
{Name: "model-v1"},
}
// Filter with dot should only match literal dot
result := filterItems(items, "model.v1")
if len(result) != 1 {
t.Errorf("expected 1 exact match, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "model.v1" {
t.Errorf("expected 'model.v1', got %s", result[0].Name)
}
// Other regex special chars should be literal too
items2 := []selectItem{
{Name: "test[0]"},
{Name: "test0"},
{Name: "test(1)"},
}
result2 := filterItems(items2, "test[0]")
if len(result2) != 1 || result2[0].Name != "test[0]" {
t.Errorf("expected only 'test[0]', got %v", result2)
}
}
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
// itemIndex uses the item name as its key, so duplicate names collide. This
// documents the current behavior: the last index for a duplicate name is stored.
func TestMultiSelectState_DuplicateNames(t *testing.T) {
// Duplicate names - this is an edge case that shouldn't happen in practice
items := []selectItem{
{Name: "duplicate"},
{Name: "duplicate"},
{Name: "unique"},
}
s := newMultiSelectState(items, nil)
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
// When there are duplicates, only the last occurrence's index is stored
if s.itemIndex["duplicate"] != 1 {
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
}
// Toggle item at highlighted=0 (first "duplicate")
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
// So it actually toggles the SECOND duplicate item, not the first
s.toggleItem()
// This documents the potentially surprising behavior:
// We toggled at highlighted=0, but itemIndex lookup returned 1
if !s.checked[1] {
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
}
if s.checked[0] {
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
}
}
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
// This prevents an index-out-of-range panic on the next keystroke.
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newSelectState(items)
s.selected = 2 // Select "cherry"
// Type a filter that removes cherry from results
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
// Selection should reset to 0
if s.selected != 0 {
t.Errorf("expected selected=0 after filter, got %d", s.selected)
}
filtered := s.filtered()
if len(filtered) != 2 {
t.Errorf("expected 2 filtered items, got %d", len(filtered))
}
}
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
// Model names may contain Unicode characters.
func TestFilterItems_UnicodeCharacters(t *testing.T) {
items := []selectItem{
{Name: "llama-日本語"},
{Name: "模型-chinese"},
{Name: "émoji-🦙"},
{Name: "regular-model"},
}
t.Run("filter japanese", func(t *testing.T) {
result := filterItems(items, "日本")
if len(result) != 1 || result[0].Name != "llama-日本語" {
t.Errorf("expected llama-日本語, got %v", result)
}
})
t.Run("filter chinese", func(t *testing.T) {
result := filterItems(items, "模型")
if len(result) != 1 || result[0].Name != "模型-chinese" {
t.Errorf("expected 模型-chinese, got %v", result)
}
})
t.Run("filter emoji", func(t *testing.T) {
result := filterItems(items, "🦙")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
t.Run("filter accented char", func(t *testing.T) {
result := filterItems(items, "émoji")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
}
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newMultiSelectState(items, nil)
s.highlighted = 2 // Highlight "cherry"
// Type a filter that removes cherry
s.handleInput(eventChar, 'a')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
}
}
// TestMultiSelectState_EmptyItems verifies that an empty item list is handled gracefully.
func TestMultiSelectState_EmptyItems(t *testing.T) {
s := newMultiSelectState([]selectItem{}, nil)
// Toggle should not panic on empty list
s.toggleItem()
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
}
// Render should handle empty list
var buf bytes.Buffer
lineCount := renderMultiSelect(&buf, "Select:", s)
if lineCount == 0 {
t.Error("renderMultiSelect should produce output even for empty list")
}
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' for empty list")
}
}
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
func TestSelectState_RenderWithDescriptions(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "First item description"},
{Name: "item2", Description: ""},
{Name: "item3", Description: "Third item"},
}
s := newSelectState(items)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "First item description") {
t.Error("expected description to be rendered")
}
if !strings.Contains(output, "item2") {
t.Error("expected item without description to be rendered")
}
}


@@ -10,19 +10,21 @@ import (
"github.com/ollama/ollama/api"
)
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return err
return errNotRunning
}
link, err := os.Readlink(exe)
if err != nil {
return err
return errNotRunning
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errors.New("could not find ollama app")
return errNotRunning
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err

109 cmd/tui/confirm.go Normal file

@@ -0,0 +1,109 @@
package tui
import (
"fmt"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
var (
confirmActiveStyle = lipgloss.NewStyle().
Bold(true).
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
confirmInactiveStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
)
type confirmModel struct {
prompt string
yes bool
confirmed bool
cancelled bool
width int
}
func (m confirmModel) Init() tea.Cmd {
return nil
}
func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "esc", "n":
m.cancelled = true
return m, tea.Quit
case "y":
m.yes = true
m.confirmed = true
return m, tea.Quit
case "enter":
m.confirmed = true
return m, tea.Quit
case "left", "h":
m.yes = true
case "right", "l":
m.yes = false
case "tab":
m.yes = !m.yes
}
}
return m, nil
}
func (m confirmModel) View() string {
if m.confirmed || m.cancelled {
return ""
}
var yesBtn, noBtn string
if m.yes {
yesBtn = confirmActiveStyle.Render(" Yes ")
noBtn = confirmInactiveStyle.Render(" No ")
} else {
yesBtn = confirmInactiveStyle.Render(" Yes ")
noBtn = confirmActiveStyle.Render(" No ")
}
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
s += " " + yesBtn + " " + noBtn + "\n\n"
s += selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
// RunConfirm shows a bubbletea yes/no confirmation prompt.
// Returns true if the user confirmed, false if cancelled.
func RunConfirm(prompt string) (bool, error) {
m := confirmModel{
prompt: prompt,
yes: true, // default to yes
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return false, fmt.Errorf("error running confirm: %w", err)
}
fm := finalModel.(confirmModel)
if fm.cancelled {
return false, ErrCancelled
}
return fm.yes, nil
}
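// Usage sketch, not part of this diff: a hypothetical caller that treats
// every cancel path (esc, ctrl+c, or 'n') as a plain "no" rather than an
// error. Assumes the standard "errors" package is added to this file's imports.
func confirmPull(name string) (bool, error) {
ok, err := RunConfirm(fmt.Sprintf("Download %s?", name))
if errors.Is(err, ErrCancelled) {
return false, nil // user backed out; read as "no"
}
if err != nil {
return false, err
}
return ok, nil
}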

208 cmd/tui/confirm_test.go Normal file

@@ -0,0 +1,208 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
)
func TestConfirmModel_DefaultsToYes(t *testing.T) {
m := confirmModel{prompt: "Download test?", yes: true}
if !m.yes {
t.Error("should default to yes")
}
}
func TestConfirmModel_View_ContainsPrompt(t *testing.T) {
m := confirmModel{prompt: "Download qwen3:8b?", yes: true}
got := m.View()
if !strings.Contains(got, "Download qwen3:8b?") {
t.Error("should contain the prompt text")
}
}
func TestConfirmModel_View_ContainsButtons(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
got := m.View()
if !strings.Contains(got, "Yes") {
t.Error("should contain Yes button")
}
if !strings.Contains(got, "No") {
t.Error("should contain No button")
}
}
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
got := m.View()
if !strings.Contains(got, "enter confirm") {
t.Error("should contain help text")
}
}
func TestConfirmModel_View_ClearsAfterConfirm(t *testing.T) {
m := confirmModel{prompt: "Download?", confirmed: true}
if m.View() != "" {
t.Error("View should return empty string after confirmation")
}
}
func TestConfirmModel_View_ClearsAfterCancel(t *testing.T) {
m := confirmModel{prompt: "Download?", cancelled: true}
if m.View() != "" {
t.Error("View should return empty string after cancellation")
}
}
func TestConfirmModel_EnterConfirmsYes(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm := updated.(confirmModel)
if !fm.confirmed {
t.Error("enter should set confirmed=true")
}
if !fm.yes {
t.Error("enter with yes selected should keep yes=true")
}
if cmd == nil {
t.Error("enter should return tea.Quit")
}
}
func TestConfirmModel_EnterConfirmsNo(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: false}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm := updated.(confirmModel)
if !fm.confirmed {
t.Error("enter should set confirmed=true")
}
if fm.yes {
t.Error("enter with no selected should keep yes=false")
}
if cmd == nil {
t.Error("enter should return tea.Quit")
}
}
func TestConfirmModel_EscCancels(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
fm := updated.(confirmModel)
if !fm.cancelled {
t.Error("esc should set cancelled=true")
}
if cmd == nil {
t.Error("esc should return tea.Quit")
}
}
func TestConfirmModel_CtrlCCancels(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
fm := updated.(confirmModel)
if !fm.cancelled {
t.Error("ctrl+c should set cancelled=true")
}
if cmd == nil {
t.Error("ctrl+c should return tea.Quit")
}
}
func TestConfirmModel_NCancels(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
fm := updated.(confirmModel)
if !fm.cancelled {
t.Error("'n' should set cancelled=true")
}
if cmd == nil {
t.Error("'n' should return tea.Quit")
}
}
func TestConfirmModel_YConfirmsYes(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: false}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
fm := updated.(confirmModel)
if !fm.confirmed {
t.Error("'y' should set confirmed=true")
}
if !fm.yes {
t.Error("'y' should set yes=true")
}
if cmd == nil {
t.Error("'y' should return tea.Quit")
}
}
func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
// Right moves to No
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'l'}})
fm := updated.(confirmModel)
if fm.yes {
t.Error("right/l should move to No")
}
if fm.confirmed || fm.cancelled {
t.Error("navigation should not confirm or cancel")
}
// Left moves back to Yes
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'h'}})
fm = updated.(confirmModel)
if !fm.yes {
t.Error("left/h should move to Yes")
}
}
func TestConfirmModel_TabToggles(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
fm := updated.(confirmModel)
if fm.yes {
t.Error("tab should toggle from Yes to No")
}
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyTab})
fm = updated.(confirmModel)
if !fm.yes {
t.Error("tab should toggle from No to Yes")
}
}
func TestConfirmModel_WindowSizeUpdatesWidth(t *testing.T) {
m := confirmModel{prompt: "Download?"}
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
fm := updated.(confirmModel)
if fm.width != 100 {
t.Errorf("expected width 100, got %d", fm.width)
}
}
func TestConfirmModel_ResizeEntersAltScreen(t *testing.T) {
m := confirmModel{prompt: "Download?", width: 80}
_, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
if cmd == nil {
t.Error("resize (width already set) should return a command")
}
}
func TestConfirmModel_InitialWindowSizeNoAltScreen(t *testing.T) {
m := confirmModel{prompt: "Download?"}
_, cmd := m.Update(tea.WindowSizeMsg{Width: 80, Height: 40})
if cmd != nil {
t.Error("initial WindowSizeMsg should not return a command")
}
}
func TestConfirmModel_ViewMaxWidth(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true, width: 40}
got := m.View()
// Just ensure it doesn't panic and returns content
if got == "" {
t.Error("View with width set should still return content")
}
}

654 cmd/tui/selector.go Normal file

@@ -0,0 +1,654 @@
package tui
import (
"errors"
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
var (
selectorTitleStyle = lipgloss.NewStyle().
Bold(true)
selectorItemStyle = lipgloss.NewStyle().
PaddingLeft(4)
selectorSelectedItemStyle = lipgloss.NewStyle().
PaddingLeft(2).
Bold(true).
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
selectorDescStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
selectorDescLineStyle = selectorDescStyle.
PaddingLeft(6)
selectorFilterStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
selectorInputStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
selectorCheckboxStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
Bold(true)
selectorDefaultTagStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
selectorHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
selectorMoreStyle = lipgloss.NewStyle().
PaddingLeft(6).
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
sectionHeaderStyle = lipgloss.NewStyle().
PaddingLeft(2).
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "240", Dark: "249"})
)
const maxSelectorItems = 10
// ErrCancelled is returned when the user cancels the selection.
var ErrCancelled = errors.New("cancelled")
type SelectItem struct {
Name string
Description string
Recommended bool
}
// selectorModel is the bubbletea model for single selection.
type selectorModel struct {
title string
items []SelectItem
filter string
cursor int
scrollOffset int
selected string
cancelled bool
helpText string
width int
}
func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m selectorModel) Init() tea.Cmd {
return nil
}
// otherStart returns the index of the first non-recommended item in the filtered list.
// When filtering, all items scroll together so this returns 0.
func (m selectorModel) otherStart() int {
if m.filter != "" {
return 0
}
filtered := m.filteredItems()
for i, item := range filtered {
if !item.Recommended {
return i
}
}
return len(filtered)
}
// updateNavigation handles navigation keys (up/down/pgup/pgdown/filter/backspace).
// It does NOT handle Enter, Esc, or CtrlC. This is used by both the standalone
// selector and the TUI modal (which intercepts Enter/Esc for its own logic).
func (m *selectorModel) updateNavigation(msg tea.KeyMsg) {
filtered := m.filteredItems()
otherStart := m.otherStart()
switch msg.Type {
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
m.updateScroll(otherStart)
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
m.updateScroll(otherStart)
}
case tea.KeyPgUp:
m.cursor -= maxSelectorItems
if m.cursor < 0 {
m.cursor = 0
}
m.updateScroll(otherStart)
case tea.KeyPgDown:
m.cursor += maxSelectorItems
if m.cursor >= len(filtered) {
m.cursor = len(filtered) - 1
}
m.updateScroll(otherStart)
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
// updateScroll adjusts scrollOffset based on cursor position.
// When not filtering, scrollOffset is relative to the "More" (non-recommended) section.
// When filtering, it's relative to the full filtered list.
func (m *selectorModel) updateScroll(otherStart int) {
if m.filter != "" {
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
return
}
// Cursor is in recommended section — reset "More" scroll to top
if m.cursor < otherStart {
m.scrollOffset = 0
return
}
// Cursor is in "More" section — scroll relative to others
posInOthers := m.cursor - otherStart
maxOthers := maxSelectorItems - otherStart
if maxOthers < 3 {
maxOthers = 3
}
if posInOthers < m.scrollOffset {
m.scrollOffset = posInOthers
}
if posInOthers >= m.scrollOffset+maxOthers {
m.scrollOffset = posInOthers - maxOthers + 1
}
}
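// scrollSketch is a hypothetical worked example of the arithmetic above, not
// part of this diff. With otherStart=2, maxOthers = maxSelectorItems-2 = 8;
// a cursor at index 12 gives posInOthers = 12-2 = 10, which is past
// scrollOffset+maxOthers, so scrollOffset becomes 10-8+1 = 3.
func scrollSketch() int {
m := selectorModel{cursor: 12}
m.updateScroll(2)
return m.scrollOffset // 3
}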
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
filtered := m.filteredItems()
if len(filtered) > 0 && m.cursor < len(filtered) {
m.selected = filtered[m.cursor].Name
}
return m, tea.Quit
default:
m.updateNavigation(msg)
}
}
return m, nil
}
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
// renderContent renders the selector content (title, items, help text) without
// checking the cancelled/selected state. This is used by both View() (standalone mode)
// and by the TUI modal which embeds a selectorModel.
func (m selectorModel) renderContent() string {
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else if m.filter != "" {
s.WriteString(sectionHeaderStyle.Render("Top Results"))
s.WriteString("\n")
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
m.renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
} else {
// Split into pinned recommended and scrollable others
var recItems, otherItems []int
for i, item := range filtered {
if item.Recommended {
recItems = append(recItems, i)
} else {
otherItems = append(otherItems, i)
}
}
// Always render all recommended items (pinned)
if len(recItems) > 0 {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
m.renderItem(&s, filtered[idx], idx)
}
}
if len(otherItems) > 0 {
s.WriteString("\n")
s.WriteString(sectionHeaderStyle.Render("More"))
s.WriteString("\n")
maxOthers := maxSelectorItems - len(recItems)
if maxOthers < 3 {
maxOthers = 3
}
displayCount := min(len(otherItems), maxOthers)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(otherItems) {
break
}
m.renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
}
s.WriteString("\n")
help := "↑/↓ navigate • enter select • esc cancel"
if m.helpText != "" {
help = m.helpText
}
s.WriteString(selectorHelpStyle.Render(help))
return s.String()
}
func (m selectorModel) View() string {
if m.cancelled || m.selected != "" {
return ""
}
s := m.renderContent()
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
func SelectSingle(title string, items []SelectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(selectorModel)
if fm.cancelled {
return "", ErrCancelled
}
return fm.selected, nil
}
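// Hypothetical usage sketch for SelectSingle, not part of this diff; the
// item list here is invented for illustration.
func pickModelSketch() (string, error) {
name, err := SelectSingle("Select model:", []SelectItem{
{Name: "llama3.2", Description: "small local model", Recommended: true},
{Name: "qwen3:8b"},
})
if errors.Is(err, ErrCancelled) {
return "", nil // user backed out; no error
}
return name, err
}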
// multiSelectorModel is the bubbletea model for multi selection.
type multiSelectorModel struct {
title string
items []SelectItem
itemIndex map[string]int
filter string
cursor int
scrollOffset int
checked map[int]bool
checkOrder []int
cancelled bool
confirmed bool
width int
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
m := multiSelectorModel{
title: title,
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
return m
}
func (m multiSelectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m *multiSelectorModel) toggleItem() {
filtered := m.filteredItems()
if len(filtered) == 0 || m.cursor >= len(filtered) {
return
}
item := filtered[m.cursor]
origIdx := m.itemIndex[item.Name]
if m.checked[origIdx] {
delete(m.checked, origIdx)
for i, idx := range m.checkOrder {
if idx == origIdx {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
} else {
m.checked[origIdx] = true
m.checkOrder = append(m.checkOrder, origIdx)
}
}
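// toggleOrderSketch is a hypothetical demo, not part of this diff, of how
// checkOrder tracks the default: check "a", check "b", then uncheck "a";
// checkOrder ends up [1], so "b" carries the "(default)" tag in View.
func toggleOrderSketch() []int {
m := newMultiSelectorModel("t", []SelectItem{{Name: "a"}, {Name: "b"}}, nil)
m.toggleItem() // cursor 0: checkOrder = [0]
m.cursor = 1
m.toggleItem() // checkOrder = [0, 1]
m.cursor = 0
m.toggleItem() // uncheck "a": checkOrder = [1]
return m.checkOrder
}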
func (m multiSelectorModel) selectedCount() int {
return len(m.checkOrder)
}
func (m multiSelectorModel) Init() tea.Cmd {
return nil
}
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
filtered := m.filteredItems()
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
m.toggleItem()
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
}
case tea.KeyPgUp:
m.cursor -= maxSelectorItems
if m.cursor < 0 {
m.cursor = 0
}
m.scrollOffset -= maxSelectorItems
if m.scrollOffset < 0 {
m.scrollOffset = 0
}
case tea.KeyPgDown:
m.cursor += maxSelectorItems
if m.cursor >= len(filtered) {
m.cursor = len(filtered) - 1
}
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
return m, nil
}
func (m multiSelectorModel) View() string {
if m.cancelled || m.confirmed {
return ""
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else {
displayCount := min(len(filtered), maxSelectorItems)
shownRecHeader := false
prevWasRec := false
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := m.itemIndex[item.Name]
if m.filter == "" {
if item.Recommended && !shownRecHeader {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
shownRecHeader = true
} else if !item.Recommended && prevWasRec {
s.WriteString("\n")
}
prevWasRec = item.Recommended
}
var checkbox string
if m.checked[origIdx] {
checkbox = selectorCheckboxCheckedStyle.Render("[x]")
} else {
checkbox = selectorCheckboxStyle.Render("[ ]")
}
var line string
if idx == m.cursor {
line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
} else {
line = " " + checkbox + " " + item.Name
}
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
line += " " + selectorDefaultTagStyle.Render("(default)")
}
s.WriteString(line)
s.WriteString("\n")
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
s.WriteString("\n")
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
result := s.String()
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(result)
}
return result
}
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
m := newMultiSelectorModel(title, items, preChecked)
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return nil, fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
}
var result []string
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
}
return result, nil
}
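// Hypothetical caller sketch for SelectMultiple, not part of this diff.
// Results come back in check order, so element 0 is the default model; Enter
// only confirms with at least one item checked, so the result is non-empty.
func pickModelsSketch(available []SelectItem, current []string) (defaultModel string, all []string, err error) {
all, err = SelectMultiple("Select models:", available, current)
if err != nil {
return "", nil, err // ErrCancelled on esc/ctrl+c
}
return all[0], all, nil
}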

410 cmd/tui/selector_test.go Normal file

@@ -0,0 +1,410 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
)
func items(names ...string) []SelectItem {
var out []SelectItem
for _, n := range names {
out = append(out, SelectItem{Name: n})
}
return out
}
func recItems(names ...string) []SelectItem {
var out []SelectItem
for _, n := range names {
out = append(out, SelectItem{Name: n, Recommended: true})
}
return out
}
func mixedItems() []SelectItem {
return []SelectItem{
{Name: "rec-a", Recommended: true},
{Name: "rec-b", Recommended: true},
{Name: "other-1"},
{Name: "other-2"},
{Name: "other-3"},
{Name: "other-4"},
{Name: "other-5"},
{Name: "other-6"},
{Name: "other-7"},
{Name: "other-8"},
{Name: "other-9"},
{Name: "other-10"},
}
}
func TestFilteredItems(t *testing.T) {
tests := []struct {
name string
items []SelectItem
filter string
want []string
}{
{
name: "no filter returns all",
items: items("alpha", "beta", "gamma"),
filter: "",
want: []string{"alpha", "beta", "gamma"},
},
{
name: "filter matches substring",
items: items("llama3.2", "qwen3:8b", "llama2"),
filter: "llama",
want: []string{"llama3.2", "llama2"},
},
{
name: "filter is case insensitive",
items: items("Qwen3:8b", "llama3.2"),
filter: "QWEN",
want: []string{"Qwen3:8b"},
},
{
name: "no matches",
items: items("alpha", "beta"),
filter: "zzz",
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := selectorModel{items: tt.items, filter: tt.filter}
got := m.filteredItems()
var gotNames []string
for _, item := range got {
gotNames = append(gotNames, item.Name)
}
if len(gotNames) != len(tt.want) {
t.Fatalf("got %v, want %v", gotNames, tt.want)
}
for i := range tt.want {
if gotNames[i] != tt.want[i] {
t.Errorf("index %d: got %q, want %q", i, gotNames[i], tt.want[i])
}
}
})
}
}
func TestOtherStart(t *testing.T) {
tests := []struct {
name string
items []SelectItem
filter string
want int
}{
{
name: "all recommended",
items: recItems("a", "b", "c"),
want: 3,
},
{
name: "none recommended",
items: items("a", "b"),
want: 0,
},
{
name: "mixed",
items: []SelectItem{
{Name: "rec-a", Recommended: true},
{Name: "rec-b", Recommended: true},
{Name: "other-1"},
{Name: "other-2"},
},
want: 2,
},
{
name: "empty",
items: nil,
want: 0,
},
{
name: "filtering returns 0",
items: []SelectItem{
{Name: "rec-a", Recommended: true},
{Name: "other-1"},
},
filter: "rec",
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := selectorModel{items: tt.items, filter: tt.filter}
if got := m.otherStart(); got != tt.want {
t.Errorf("otherStart() = %d, want %d", got, tt.want)
}
})
}
}
func TestUpdateScroll(t *testing.T) {
tests := []struct {
name string
cursor int
offset int
otherStart int
filter string
wantOffset int
}{
{
name: "cursor in recommended resets scroll",
cursor: 1,
offset: 5,
otherStart: 3,
wantOffset: 0,
},
{
name: "cursor at start of others",
cursor: 2,
offset: 0,
otherStart: 2,
wantOffset: 0,
},
{
name: "cursor scrolls down in others",
cursor: 12,
offset: 0,
otherStart: 2,
wantOffset: 3, // posInOthers=10, maxOthers=8, 10-8+1=3
},
{
name: "cursor scrolls up in others",
cursor: 4,
offset: 5,
otherStart: 2,
wantOffset: 2, // posInOthers=2 < offset=5
},
{
name: "filter mode standard scroll down",
cursor: 12,
offset: 0,
filter: "x",
otherStart: 0,
wantOffset: 3, // 12 - 10 + 1 = 3
},
{
name: "filter mode standard scroll up",
cursor: 2,
offset: 5,
filter: "x",
otherStart: 0,
wantOffset: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := selectorModel{
cursor: tt.cursor,
scrollOffset: tt.offset,
filter: tt.filter,
}
m.updateScroll(tt.otherStart)
if m.scrollOffset != tt.wantOffset {
t.Errorf("scrollOffset = %d, want %d", m.scrollOffset, tt.wantOffset)
}
})
}
}
func TestRenderContent_SectionHeaders(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: []SelectItem{
{Name: "rec-a", Recommended: true},
{Name: "other-1"},
},
}
content := m.renderContent()
if !strings.Contains(content, "Recommended") {
t.Error("should contain 'Recommended' header")
}
if !strings.Contains(content, "More") {
t.Error("should contain 'More' header")
}
}
func TestRenderContent_FilteredHeader(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: items("alpha", "beta", "alphabet"),
filter: "alpha",
}
content := m.renderContent()
if !strings.Contains(content, "Top Results") {
t.Error("filtered view should contain 'Top Results' header")
}
if strings.Contains(content, "Recommended") {
t.Error("filtered view should not contain 'Recommended' header")
}
}
func TestRenderContent_NoMatches(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: items("alpha"),
filter: "zzz",
}
content := m.renderContent()
if !strings.Contains(content, "(no matches)") {
t.Error("should show '(no matches)' when filter has no results")
}
}
func TestRenderContent_SelectedItemIndicator(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: items("alpha", "beta"),
cursor: 0,
}
content := m.renderContent()
if !strings.Contains(content, "▸") {
t.Error("selected item should have ▸ indicator")
}
}
func TestRenderContent_Description(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: []SelectItem{
{Name: "alpha", Description: "the first letter"},
},
}
content := m.renderContent()
if !strings.Contains(content, "the first letter") {
t.Error("should render item description")
}
}
func TestRenderContent_PinnedRecommended(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: mixedItems(),
// cursor deep in "More" section
cursor: 8,
scrollOffset: 3,
}
content := m.renderContent()
// Recommended items should always be visible (pinned)
if !strings.Contains(content, "rec-a") {
t.Error("recommended items should always be rendered (pinned)")
}
if !strings.Contains(content, "rec-b") {
t.Error("recommended items should always be rendered (pinned)")
}
}
func TestRenderContent_MoreOverflowIndicator(t *testing.T) {
m := selectorModel{
title: "Pick:",
items: mixedItems(), // 2 rec + 10 other = 12 total, maxSelectorItems=10
}
content := m.renderContent()
if !strings.Contains(content, "... and") {
t.Error("should show overflow indicator when more items than visible")
}
}
func TestUpdateNavigation_CursorBounds(t *testing.T) {
m := selectorModel{
items: items("a", "b", "c"),
cursor: 0,
}
// Up at top stays at 0
m.updateNavigation(keyMsg(KeyUp))
if m.cursor != 0 {
t.Errorf("cursor should stay at 0 when pressing up at top, got %d", m.cursor)
}
// Down moves to 1
m.updateNavigation(keyMsg(KeyDown))
if m.cursor != 1 {
t.Errorf("cursor should be 1 after down, got %d", m.cursor)
}
// Down to end
m.updateNavigation(keyMsg(KeyDown))
m.updateNavigation(keyMsg(KeyDown))
if m.cursor != 2 {
t.Errorf("cursor should be 2 at bottom, got %d", m.cursor)
}
}
func TestUpdateNavigation_FilterResetsState(t *testing.T) {
m := selectorModel{
items: items("alpha", "beta"),
cursor: 1,
scrollOffset: 5,
}
m.updateNavigation(runeMsg('x'))
if m.filter != "x" {
t.Errorf("filter should be 'x', got %q", m.filter)
}
if m.cursor != 0 {
t.Errorf("cursor should reset to 0 on filter, got %d", m.cursor)
}
if m.scrollOffset != 0 {
t.Errorf("scrollOffset should reset to 0 on filter, got %d", m.scrollOffset)
}
}
func TestUpdateNavigation_Backspace(t *testing.T) {
m := selectorModel{
items: items("alpha"),
filter: "abc",
cursor: 1,
}
m.updateNavigation(keyMsg(KeyBackspace))
if m.filter != "ab" {
t.Errorf("filter should be 'ab' after backspace, got %q", m.filter)
}
if m.cursor != 0 {
t.Errorf("cursor should reset to 0 on backspace, got %d", m.cursor)
}
}
// Key message helpers for testing
type keyType = int
const (
KeyUp keyType = iota
KeyDown
KeyBackspace
)
func keyMsg(k keyType) tea.KeyMsg {
switch k {
case KeyUp:
return tea.KeyMsg{Type: tea.KeyUp}
case KeyDown:
return tea.KeyMsg{Type: tea.KeyDown}
case KeyBackspace:
return tea.KeyMsg{Type: tea.KeyBackspace}
default:
return tea.KeyMsg{}
}
}
func runeMsg(r rune) tea.KeyMsg {
return tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}}
}

128 cmd/tui/signin.go Normal file

@@ -0,0 +1,128 @@
package tui
import (
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/config"
)
type signInModel struct {
modelName string
signInURL string
spinner int
width int
userName string
cancelled bool
}
func (m signInModel) Init() tea.Cmd {
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
}
func (m signInModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
}
case signInTickMsg:
m.spinner++
if m.spinner%5 == 0 {
return m, tea.Batch(
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
}),
checkSignIn,
)
}
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
case signInCheckMsg:
if msg.signedIn {
m.userName = msg.userName
return m, tea.Quit
}
}
return m, nil
}
func (m signInModel) View() string {
if m.userName != "" {
return ""
}
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
}
func renderSignIn(modelName, signInURL string, spinner, width int) string {
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
frame := spinnerFrames[spinner%len(spinnerFrames)]
urlColor := lipgloss.NewStyle().
Foreground(lipgloss.Color("117"))
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
if width > 4 {
urlWrap = urlWrap.Width(width - 4)
}
var s strings.Builder
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
// Padding is outside the hyperlink so spaces don't get underlined.
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
s.WriteString("Navigate to:\n")
s.WriteString(urlWrap.Render(link))
s.WriteString("\n\n")
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
frame + " Waiting for sign in to complete..."))
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("esc cancel"))
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
}
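// For reference, the OSC 8 sequence built above has the shape
// ESC]8;;URL ESC\ <text> ESC]8;;ESC\. A bare helper would look like this
// (a sketch, not part of this diff):
func hyperlinkSketch(url, text string) string {
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
}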
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
func RunSignIn(modelName, signInURL string) (string, error) {
config.OpenBrowser(signInURL)
m := signInModel{
modelName: modelName,
signInURL: signInURL,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running sign-in: %w", err)
}
fm := finalModel.(signInModel)
if fm.cancelled {
return "", ErrCancelled
}
return fm.userName, nil
}
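// Hypothetical wiring sketch, not part of this diff; it assumes the api,
// context, and errors packages are imported here. In the real flow the
// sign-in URL comes from an api.AuthorizationError, as in tui.go's
// checkCloudSignIn.
func signInIfNeededSketch(client *api.Client, modelName string) (string, error) {
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return user.Name, nil // already signed in
}
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
return RunSignIn(modelName, aErr.SigninURL)
}
return "", err
}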

175 cmd/tui/signin_test.go Normal file

@@ -0,0 +1,175 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
)
func TestRenderSignIn_ContainsModelName(t *testing.T) {
got := renderSignIn("glm-4.7:cloud", "https://example.com/signin", 0, 80)
if !strings.Contains(got, "glm-4.7:cloud") {
t.Error("should contain model name")
}
if !strings.Contains(got, "please sign in") {
t.Error("should contain sign-in prompt")
}
}
func TestRenderSignIn_ContainsURL(t *testing.T) {
url := "https://ollama.com/connect?key=abc123"
got := renderSignIn("test:cloud", url, 0, 120)
if !strings.Contains(got, url) {
t.Errorf("should contain URL %q", url)
}
}
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
url := "https://ollama.com/connect?key=abc123"
got := renderSignIn("test:cloud", url, 0, 120)
// Should contain OSC 8 open sequence with the URL
osc8Open := "\033]8;;" + url + "\033\\"
if !strings.Contains(got, osc8Open) {
t.Error("should contain OSC 8 open sequence with URL")
}
// Should contain OSC 8 close sequence
osc8Close := "\033]8;;\033\\"
if !strings.Contains(got, osc8Close) {
t.Error("should contain OSC 8 close sequence")
}
}
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
if !strings.Contains(got, "Waiting for sign in to complete") {
t.Error("should contain waiting message")
}
if !strings.Contains(got, "⠋") {
t.Error("should contain first spinner frame at spinner=0")
}
}
func TestRenderSignIn_SpinnerAdvances(t *testing.T) {
got0 := renderSignIn("test:cloud", "https://example.com", 0, 80)
got1 := renderSignIn("test:cloud", "https://example.com", 1, 80)
if got0 == got1 {
t.Error("different spinner values should produce different output")
}
}
func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
if !strings.Contains(got, "esc cancel") {
t.Error("should contain esc cancel help text")
}
}
func TestSignInModel_EscCancels(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
fm := updated.(signInModel)
if !fm.cancelled {
t.Error("esc should set cancelled=true")
}
if cmd == nil {
t.Error("esc should return tea.Quit")
}
}
func TestSignInModel_CtrlCCancels(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
fm := updated.(signInModel)
if !fm.cancelled {
t.Error("ctrl+c should set cancelled=true")
}
if cmd == nil {
t.Error("ctrl+c should return tea.Quit")
}
}
func TestSignInModel_SignedInQuitsClean(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, cmd := m.Update(signInCheckMsg{signedIn: true, userName: "alice"})
fm := updated.(signInModel)
if fm.userName != "alice" {
t.Errorf("expected userName 'alice', got %q", fm.userName)
}
if cmd == nil {
t.Error("successful sign-in should return tea.Quit")
}
}
func TestSignInModel_SignedInViewClears(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
userName: "alice",
}
got := m.View()
if got != "" {
t.Errorf("View should return empty string after sign-in, got %q", got)
}
}
func TestSignInModel_NotSignedInContinues(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, _ := m.Update(signInCheckMsg{signedIn: false})
fm := updated.(signInModel)
if fm.userName != "" {
t.Error("should not set userName when not signed in")
}
if fm.cancelled {
t.Error("should not cancel when check returns not signed in")
}
}
func TestSignInModel_WindowSizeUpdatesWidth(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
fm := updated.(signInModel)
if fm.width != 120 {
t.Errorf("expected width 120, got %d", fm.width)
}
}
func TestSignInModel_TickAdvancesSpinner(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
spinner: 0,
}
updated, cmd := m.Update(signInTickMsg{})
fm := updated.(signInModel)
if fm.spinner != 1 {
t.Errorf("expected spinner=1, got %d", fm.spinner)
}
if cmd == nil {
t.Error("tick should return a command")
}
}

603 cmd/tui/tui.go Normal file

@@ -0,0 +1,603 @@
package tui
import (
"context"
"errors"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/version"
)
var (
versionStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
menuItemStyle = lipgloss.NewStyle().
PaddingLeft(2)
menuSelectedItemStyle = lipgloss.NewStyle().
Bold(true).
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
menuDescStyle = selectorDescStyle.
PaddingLeft(4)
greyedStyle = menuItemStyle.
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
greyedSelectedStyle = menuSelectedItemStyle.
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
modelStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
notInstalledStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
)
type menuItem struct {
title string
description string
integration string // integration name for loading model config, empty if not an integration
isRunModel bool
isOthers bool
}
var mainMenuItems = []menuItem{
{
title: "Run a model",
description: "Start an interactive chat with a model",
isRunModel: true,
},
{
title: "Launch Claude Code",
description: "Agentic coding across large codebases",
integration: "claude",
},
{
title: "Launch Codex",
description: "OpenAI's open-source coding agent",
integration: "codex",
},
{
title: "Launch OpenClaw",
description: "Personal AI with 100+ skills",
integration: "openclaw",
},
}
var othersMenuItem = menuItem{
title: "More...",
description: "Show additional integrations",
isOthers: true,
}
// getOtherIntegrations dynamically builds the "Others" list from the integration
// registry, excluding any integrations already present in the pinned mainMenuItems.
func getOtherIntegrations() []menuItem {
pinned := map[string]bool{
"run": true, // not an integration but in the pinned list
}
for _, item := range mainMenuItems {
if item.integration != "" {
pinned[item.integration] = true
}
}
var others []menuItem
for _, info := range config.ListIntegrationInfos() {
if pinned[info.Name] {
continue
}
desc := info.Description
if desc == "" {
desc = "Open " + info.DisplayName + " integration"
}
others = append(others, menuItem{
title: "Launch " + info.DisplayName,
description: desc,
integration: info.Name,
})
}
return others
}
type model struct {
items []menuItem
cursor int
quitting bool
selected bool
changeModel bool
showOthers bool
availableModels map[string]bool
err error
showingModal bool
modalSelector selectorModel
modalItems []SelectItem
showingSignIn bool
signInURL string
signInModel string
signInSpinner int
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
width int // terminal width from WindowSizeMsg
statusMsg string // temporary status message shown near help text
}
type signInTickMsg struct{}
type signInCheckMsg struct {
signedIn bool
userName string
}
type clearStatusMsg struct{}
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
return false
}
if m.availableModels[name] {
return true
}
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
for modelName := range m.availableModels {
if strings.HasPrefix(modelName, name+":") {
return true
}
}
return false
}
func (m *model) buildModalItems() []SelectItem {
modelItems, _ := config.GetModelItems(context.Background())
var items []SelectItem
for _, item := range modelItems {
items = append(items, SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended})
}
return items
}
func (m *model) openModelModal() {
m.modalItems = m.buildModalItems()
m.modalSelector = selectorModel{
title: "Select model:",
items: m.modalItems,
helpText: "↑/↓ navigate • enter select • ← back",
}
m.showingModal = true
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
}
// checkCloudSignIn checks if a cloud model needs sign-in.
// Returns a command to start sign-in if needed, or nil if already signed in.
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
if modelName == "" || !isCloudModel(modelName) {
return nil
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
}
return nil
}
// startSignIn initiates the sign-in flow for a cloud model.
// fromModal indicates if this was triggered from the model picker modal.
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
m.showingModal = false
m.showingSignIn = true
m.signInURL = signInURL
m.signInModel = modelName
m.signInSpinner = 0
m.signInFromModal = fromModal
config.OpenBrowser(signInURL)
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
}
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
func (m *model) loadAvailableModels() {
m.availableModels = make(map[string]bool)
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
models, err := client.List(context.Background())
if err != nil {
return
}
for _, mdl := range models.Models {
m.availableModels[mdl.Name] = true
}
}
func (m *model) buildItems() {
others := getOtherIntegrations()
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
m.items = append(m.items, mainMenuItems...)
if m.showOthers {
m.items = append(m.items, others...)
} else {
m.items = append(m.items, othersMenuItem)
}
}
func isOthersIntegration(name string) bool {
for _, item := range getOtherIntegrations() {
if item.integration == name {
return true
}
}
return false
}
func initialModel() model {
m := model{
cursor: 0,
}
m.loadAvailableModels()
lastSelection := config.LastSelection()
if isOthersIntegration(lastSelection) {
m.showOthers = true
}
m.buildItems()
if lastSelection != "" {
for i, item := range m.items {
if lastSelection == "run" && item.isRunModel {
m.cursor = i
break
} else if item.integration == lastSelection {
m.cursor = i
break
}
}
}
return m
}
func (m model) Init() tea.Cmd {
return nil
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
wasSet := m.width > 0
m.width = wmsg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
}
if _, ok := msg.(clearStatusMsg); ok {
m.statusMsg = ""
return m, nil
}
if m.showingSignIn {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.showingSignIn = false
if m.signInFromModal {
m.showingModal = true
}
return m, nil
}
case signInTickMsg:
m.signInSpinner++
// Check sign-in status every 5th tick (~1 second)
if m.signInSpinner%5 == 0 {
return m, tea.Batch(
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
}),
checkSignIn,
)
}
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
case signInCheckMsg:
if msg.signedIn {
if m.signInFromModal {
m.modalSelector.selected = m.signInModel
m.changeModel = true
} else {
m.selected = true
}
m.quitting = true
return m, tea.Quit
}
}
return m, nil
}
if m.showingModal {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
m.showingModal = false
return m, nil
case tea.KeyEnter:
filtered := m.modalSelector.filteredItems()
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
}
if m.modalSelector.selected != "" {
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
return m, cmd
}
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
return m, nil
default:
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
m.modalSelector.updateNavigation(msg)
}
}
return m, nil
}
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q", "esc":
m.quitting = true
return m, tea.Quit
case "up", "k":
if m.cursor > 0 {
m.cursor--
}
// Auto-collapse the "More..." section when the cursor moves back into pinned items
if m.showOthers && m.cursor < len(mainMenuItems) {
m.showOthers = false
m.buildItems()
}
case "down", "j":
if m.cursor < len(m.items)-1 {
m.cursor++
}
// Auto-expand "Others..." when cursor lands on it
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
m.showOthers = true
m.buildItems()
// cursor now points at the first "other" integration
}
case "enter", " ":
item := m.items[m.cursor]
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
}
var configuredModel string
if item.isRunModel {
configuredModel = config.LastModel()
} else if item.integration != "" {
configuredModel = config.IntegrationModel(item.integration)
}
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
return m, cmd
}
m.selected = true
m.quitting = true
return m, tea.Quit
case "right", "l":
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
}
m.openModelModal()
}
}
}
return m, nil
}
func (m model) View() string {
if m.quitting {
return ""
}
if m.showingSignIn {
return m.renderSignInDialog()
}
if m.showingModal {
return m.renderModal()
}
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
for i, item := range m.items {
cursor := ""
style := menuItemStyle
isInstalled := true
if item.integration != "" {
isInstalled = config.IsIntegrationInstalled(item.integration)
}
if m.cursor == i {
cursor = "▸ "
if isInstalled {
style = menuSelectedItemStyle
} else {
style = greyedSelectedStyle
}
} else if !isInstalled && item.integration != "" {
style = greyedStyle
}
title := item.title
var modelSuffix string
if item.integration != "" {
if !isInstalled {
title += " " + notInstalledStyle.Render("(not installed)")
} else if m.cursor == i {
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
}
}
} else if item.isRunModel && m.cursor == i {
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
}
}
s += style.Render(cursor+title) + modelSuffix + "\n"
desc := item.description
if !isInstalled && item.integration != "" && m.cursor == i {
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
desc = hint
} else {
desc = "not installed"
}
}
s += menuDescStyle.Render(desc) + "\n\n"
}
if m.statusMsg != "" {
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
}
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
func (m model) renderModal() string {
modalStyle := lipgloss.NewStyle().
PaddingBottom(1).
PaddingRight(2)
s := modalStyle.Render(m.modalSelector.renderContent())
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
func (m model) renderSignInDialog() string {
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
}
type Selection int
const (
SelectionNone Selection = iota
SelectionRunModel
SelectionChangeRunModel
SelectionIntegration // Generic integration selection
SelectionChangeIntegration // Generic change model for integration
)
type Result struct {
Selection Selection
Integration string // integration name if applicable
Model string // model name if selected from modal
}
// Run starts the interactive menu and blocks until the user makes a
// selection or quits, reducing the final TUI state to a Result.
func Run() (Result, error) {
m := initialModel()
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
}
fm := finalModel.(model)
if fm.err != nil {
return Result{Selection: SelectionNone}, fm.err
}
if !fm.selected && !fm.changeModel {
return Result{Selection: SelectionNone}, nil
}
item := fm.items[fm.cursor]
if fm.changeModel {
if item.isRunModel {
return Result{
Selection: SelectionChangeRunModel,
Model: fm.modalSelector.selected,
}, nil
}
return Result{
Selection: SelectionChangeIntegration,
Integration: item.integration,
Model: fm.modalSelector.selected,
}, nil
}
if item.isRunModel {
return Result{Selection: SelectionRunModel}, nil
}
return Result{
Selection: SelectionIntegration,
Integration: item.integration,
}, nil
}
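Run reduces the final Bubble Tea model to a Result and leaves dispatch to the caller. A sketch of how that might look; runModel, launchIntegration, and SetIntegrationModel here are hypothetical stand-ins for the cmd package's real handlers:

```go
// Hypothetical caller; only Run, Result, and config.LastModel are taken
// from the code above — the helpers are illustrative.
func handleSelection() error {
	res, err := Run()
	if err != nil {
		return err
	}
	switch res.Selection {
	case SelectionRunModel:
		return runModel(config.LastModel())
	case SelectionChangeRunModel:
		return runModel(res.Model)
	case SelectionIntegration:
		return launchIntegration(res.Integration, "")
	case SelectionChangeIntegration:
		config.SetIntegrationModel(res.Integration, res.Model) // hypothetical setter
		return launchIntegration(res.Integration, res.Model)
	}
	return nil // SelectionNone: the user quit without choosing
}
```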

View File

@@ -5,7 +5,10 @@ title: Context length
Context length is the maximum number of tokens that the model has access to in memory.
<Note>
The default context length in Ollama is 4096 tokens.
Ollama defaults to the following context lengths based on VRAM:
- < 24 GiB VRAM: 4k context
- 24-48 GiB VRAM: 32k context
- >= 48 GiB VRAM: 256k context
</Note>
Tasks that require a large context, such as web search, agents, and coding tools, should use a context length of at least 64000 tokens.
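To make that concrete: the context window can be raised server-wide with the `OLLAMA_CONTEXT_LENGTH` environment variable (referenced in the FAQ changes later in this compare), or per request through the API's `num_ctx` option. A minimal sketch:

```sh
# Server-wide default of 64k tokens
OLLAMA_CONTEXT_LENGTH=64000 ollama serve

# Or per request via the num_ctx option
curl http://localhost:11434/api/generate -d '{
  "model": "gemma3",
  "prompt": "Summarize this document",
  "options": { "num_ctx": 64000 }
}'
```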

View File

@@ -105,21 +105,52 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/cline",
"/integrations/openclaw",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",
"/integrations/jetbrains",
"/integrations/marimo",
"/integrations/n8n",
"/integrations/onyx",
"/integrations/opencode",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
"/integrations/index",
{
"group": "Coding",
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose"
]
},
{
"group": "Assistants",
"pages": [
"/integrations/openclaw"
]
},
{
"group": "IDEs & Editors",
"pages": [
"/integrations/cline",
"/integrations/jetbrains",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
]
},
{
"group": "Chat & RAG",
"pages": [
"/integrations/onyx"
]
},
{
"group": "Automation",
"pages": [
"/integrations/n8n"
]
},
{
"group": "Notebooks",
"pages": [
"/integrations/marimo"
]
}
]
},
{

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting.mdx) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu).
Please refer to the [GPU docs](./gpu.mdx).
## How can I specify the context window size?
@@ -66,7 +66,7 @@ llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
</Info>
The `Processor` column will show which memory the model was loaded in to:
The `Processor` column will show which memory the model was loaded into:
- `100% GPU` means the model was loaded entirely into the GPU
- `100% CPU` means the model was loaded entirely in system memory
@@ -158,7 +158,7 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
## Does Ollama send my prompts and answers back to ollama.com?
No. Ollama runs locally, and conversation data does not leave your machine.
Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
## How can I expose Ollama on my network?
@@ -183,7 +183,7 @@ server {
## How can I use Ollama with ngrok?
Ollama can be accessed using a range of tools for tunneling tools. For example with Ngrok:
Ollama can be accessed using a range of tunneling apps. For example with Ngrok:
```shell
ngrok http 11434 --host-header="localhost:11434"
@@ -240,7 +240,7 @@ GPU acceleration is not available for Docker Desktop in macOS due to the lack of
This can impact both installing Ollama, as well as downloading models.
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`.
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernet (WSL)` adapter, right click and select `Properties`.
Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. _Disable_ both of these
properties.
@@ -299,7 +299,7 @@ The `keep_alive` API parameter with the `/api/generate` and `/api/chat` API endp
## How do I manage the maximum number of requests the Ollama server can queue?
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queued by setting `OLLAMA_MAX_QUEUE`.
## How does Ollama handle concurrent requests?
@@ -312,10 +312,10 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPU's VRAM.
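Combining the settings above, a sketch of a server tuned for modest concurrency (values are illustrative; remember that required RAM scales with `OLLAMA_NUM_PARALLEL` \* `OLLAMA_CONTEXT_LENGTH`):

```sh
# Illustrative values: two resident models, two parallel requests per
# model, and a shorter queue before the server answers 503.
OLLAMA_MAX_LOADED_MODELS=2 \
OLLAMA_NUM_PARALLEL=2 \
OLLAMA_MAX_QUEUE=128 \
ollama serve
```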
## How does Ollama load models on multiple GPUs?
@@ -382,7 +382,7 @@ ollama signin
Replace &lt;username&gt; with your actual Windows user name.
</Note>
## How can I stop Ollama from starting when I login to my computer
## How can I stop Ollama from starting when I login to my computer?
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
@@ -390,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under `Allow in the Background`, then click the slider to disable.

View File

@@ -0,0 +1,50 @@
---
title: Overview
---
Ollama integrates with a wide range of tools.
## Coding Agents
Coding assistants that can read, modify, and execute code in your projects.
- [Claude Code](/integrations/claude-code)
- [Codex](/integrations/codex)
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
## Assistants
AI assistants that help with everyday tasks.
- [OpenClaw](/integrations/openclaw)
## IDEs & Editors
Native integrations for popular development environments.
- [VS Code](/integrations/vscode)
- [Cline](/integrations/cline)
- [Roo Code](/integrations/roo-code)
- [JetBrains](/integrations/jetbrains)
- [Xcode](/integrations/xcode)
- [Zed](/integrations/zed)
## Chat & RAG
Chat interfaces and retrieval-augmented generation platforms.
- [Onyx](/integrations/onyx)
## Automation
Workflow automation platforms with AI integration.
- [n8n](/integrations/n8n)
## Notebooks
Interactive computing environments with AI capabilities.
- [marimo](/integrations/marimo)

View File

@@ -2,7 +2,7 @@
title: Quickstart
---
This quickstart will walk your through running your first model with Ollama. To get started, download Ollama on macOS, Windows or Linux.
Ollama is available on macOS, Windows or Linux.
<a
href="https://ollama.com/download"
@@ -12,131 +12,48 @@ This quickstart will walk your through running your first model with Ollama. To
Download Ollama
</a>
## Run a model
## Get Started
<Tabs>
<Tab title="CLI">
Open a terminal and run the command:
```sh
ollama run gemma3
```
</Tab>
<Tab title="cURL">
```sh
ollama pull gemma3
```
Lastly, chat with the model:
```shell
curl http://localhost:11434/api/chat -d '{
"model": "gemma3",
"messages": [{
"role": "user",
"content": "Hello there!"
}],
"stream": false
}'
```
</Tab>
<Tab title="Python">
Start by downloading a model:
```sh
ollama pull gemma3
```
Then install Ollama's Python library:
```sh
pip install ollama
```
Lastly, chat with the model:
```python
from ollama import chat
from ollama import ChatResponse
response: ChatResponse = chat(model='gemma3', messages=[
{
'role': 'user',
'content': 'Why is the sky blue?',
},
])
print(response['message']['content'])
# or access fields directly from the response object
print(response.message.content)
```
</Tab>
<Tab title="JavaScript">
Start by downloading a model:
```
ollama pull gemma3
```
Then install the Ollama JavaScript library:
```
npm i ollama
```
Lastly, chat with the model:
```shell
import ollama from 'ollama'
const response = await ollama.chat({
model: 'gemma3',
messages: [{ role: 'user', content: 'Why is the sky blue?' }],
})
console.log(response.message.content)
```
</Tab>
</Tabs>
See a full list of available models [here](https://ollama.com/models).
## Coding
For coding use cases, we recommend using the `glm-4.7-flash` model.
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
```sh
ollama pull glm-4.7-flash
```
Alternatively, you can use a more powerful cloud model (with full context length):
```sh
ollama pull glm-4.7:cloud
```
Use `ollama launch` to quickly set up a coding tool with Ollama models:
Run `ollama` in your terminal of choice to open the interactive menu:
```sh
ollama launch
ollama
```
### Supported integrations
Navigate with `↑/↓`, press `enter` to launch, `→` to change model, and `esc` to quit.
- [OpenCode](/integrations/opencode) - Open-source coding assistant
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
- [Codex](/integrations/codex) - OpenAI's coding assistant
- [Droid](/integrations/droid) - Factory's AI coding agent
The menu provides quick access to:
- **Run a model** - Start an interactive chat
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
- **Additional integrations** - Available under "More..."
### Launch with a specific model
## API
Use the [API](/api) to integrate Ollama into your applications:
```sh
ollama launch claude --model glm-4.7-flash
curl http://localhost:11434/api/chat -d '{
"model": "gemma3",
"messages": [{ "role": "user", "content": "Hello!" }]
}'
```
### Configure without launching
See the [API documentation](/api) for Python, JavaScript, and other integrations.
## Coding
Launch coding tools with Ollama models:
```sh
ollama launch claude --config
ollama launch claude
```
```sh
ollama launch codex
```
```sh
ollama launch opencode
```
See [integrations](/integrations) for all supported tools.

go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -21,14 +21,18 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/mattn/go-runewidth v0.0.14
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -38,22 +42,35 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tkrajina/go-reflector v0.5.5 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect

go.sum
View File

@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -148,13 +164,19 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -162,6 +184,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -182,8 +210,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -206,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -220,6 +276,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -306,6 +364,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -116,7 +117,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tok tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tok, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tok == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tok == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tok != nil,
modelPath,
gpuLibs,
status,
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}
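The new tokenizer package itself isn't shown in this compare view, but its interface can be inferred from the call sites above and the constructors in the model diffs below. A sketch of the assumed shape — the real interface likely also carries the Vocabulary and Is methods seen on the old BytePairEncoding:

```go
// Assumed shape of tokenizer.Tokenizer, inferred from the Encode/Decode
// call sites in llm/server.go; not the package's authoritative definition.
type Tokenizer interface {
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode(ids []int32) (string, error)
}
```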

View File

@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
messageID := anthropic.GenerateMessageID()
// Estimate input tokens for streaming (actual count not available until generation completes)
estimatedTokens := anthropic.EstimateInputTokens(req)
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
}
if req.Stream {
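`anthropic.EstimateInputTokens` is referenced here but not defined in this view. Estimators like this commonly approximate four characters per token for English text; a hypothetical sketch under that assumption:

```go
// anthropicMessage is a stand-in for the real request message type.
type anthropicMessage struct {
	Content string
}

// estimateInputTokens is a hypothetical sketch; the real
// anthropic.EstimateInputTokens may weigh roles, images, and tool
// definitions differently. Assumes roughly 4 characters per token.
func estimateInputTokens(messages []anthropicMessage) int {
	chars := 0
	for _, m := range messages {
		chars += len(m.Content)
	}
	if n := chars / 4; n > 1 {
		return n
	}
	return 1
}
```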

View File

@@ -1,272 +0,0 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and its corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}
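The switch statements in Encode and Decode above implement the GPT-2 style byte-to-rune remapping: raw bytes that are unprintable (control characters, space, the DEL..NBSP range) are shifted into printable rune ranges before vocabulary lookup, and shifted back on decode. A standalone round-trip sketch of just that mapping:

```go
// encodeByte maps a raw byte into the printable-rune alphabet used by
// byte-level BPE vocabularies, mirroring the Encode switch above.
func encodeByte(b byte) rune {
	r := rune(b)
	switch {
	case r == 0x00ad: // soft hyphen
		return 0x0143
	case r <= 0x0020: // control characters and space
		return r + 0x0100
	case r >= 0x007f && r <= 0x00a0: // DEL through NBSP
		return r + 0x00a2
	}
	return r
}

// decodeRune inverts encodeByte, mirroring the Decode switch above
// (which additionally skips 0x0100, the NULL mapping, entirely).
func decodeRune(r rune) byte {
	switch {
	case r == 0x0143:
		return 0x00ad
	case r > 0x0100 && r <= 0x0120:
		return byte(r - 0x0100)
	case r > 0x0120 && r <= 0x0142:
		return byte(r - 0x00a2)
	}
	return byte(r)
}
```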

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}
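With this change, model.NewTextProcessor hands back the tokenizer interface directly. An illustrative round trip over a GGUF path (the path and prompt are placeholders):

```go
func roundTrip(path string) (string, error) {
	tok, err := model.NewTextProcessor(path) // e.g. a local .gguf blob
	if err != nil {
		return "", err
	}
	ids, err := tok.Encode("why is the sky blue?", true) // addSpecial adds BOS/EOS per vocab flags
	if err != nil {
		return "", err
	}
	return tok.Decode(ids) // should reproduce the prompt, plus any specials
}
```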

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,28 +45,24 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,53 +0,0 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tok.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
@@ -11,22 +12,15 @@ func Execute(args []string) error {
args = args[1:]
}
var newRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
mlxRunner = true
}
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
if len(args) > 0 {
switch args[0] {
case "--ollama-engine":
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
}
return llamarunner.Execute(args)
}

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
&model.Vocabulary{
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -302,12 +302,22 @@ function deps {
}
function sign {
# Copy install.ps1 to dist for release packaging
write-host "Copying install.ps1 to dist"
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
if ("${env:KEY_CONTAINER}") {
write-host "Signing Ollama executables, scripts and libraries"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
write-host "Signing install.ps1"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
"${script:SRC_DIR}\dist\install.ps1"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "Signing not enabled"
}

271
scripts/install.ps1 Normal file
View File

@@ -0,0 +1,271 @@
<#
.SYNOPSIS
Install, upgrade, or uninstall Ollama on Windows.
.DESCRIPTION
Downloads and installs Ollama.
Quick install:
irm https://ollama.com/install.ps1 | iex
Specific version:
$env:OLLAMA_VERSION="0.5.7"; irm https://ollama.com/install.ps1 | iex
Custom install directory:
$env:OLLAMA_INSTALL_DIR="D:\Ollama"; irm https://ollama.com/install.ps1 | iex
Uninstall:
$env:OLLAMA_UNINSTALL=1; irm https://ollama.com/install.ps1 | iex
Environment variables:
OLLAMA_VERSION Target version (default: latest stable)
OLLAMA_INSTALL_DIR Custom install directory
OLLAMA_UNINSTALL Set to 1 to uninstall Ollama
OLLAMA_DEBUG Enable verbose output
.EXAMPLE
irm https://ollama.com/install.ps1 | iex
.EXAMPLE
$env:OLLAMA_VERSION = "0.5.7"; irm https://ollama.com/install.ps1 | iex
.LINK
https://ollama.com
#>
$ErrorActionPreference = "Stop"
$ProgressPreference = "SilentlyContinue"
# --------------------------------------------------------------------------
# Configuration from environment variables
# --------------------------------------------------------------------------
$Version = if ($env:OLLAMA_VERSION) { $env:OLLAMA_VERSION } else { "" }
$InstallDir = if ($env:OLLAMA_INSTALL_DIR) { $env:OLLAMA_INSTALL_DIR } else { "" }
$Uninstall = $env:OLLAMA_UNINSTALL -eq "1"
$DebugInstall = [bool]$env:OLLAMA_DEBUG
# --------------------------------------------------------------------------
# Constants
# --------------------------------------------------------------------------
# OLLAMA_DOWNLOAD_URL for developer testing only
$DownloadBaseURL = if ($env:OLLAMA_DOWNLOAD_URL) { $env:OLLAMA_DOWNLOAD_URL.TrimEnd('/') } else { "https://ollama.com/download" }
$InnoSetupUninstallGuid = "{44E83376-CE68-45EB-8FC1-393500EB558C}_is1"
# --------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------
function Write-Status {
param([string]$Message)
if ($DebugInstall) { Write-Host $Message }
}
function Write-Step {
param([string]$Message)
if ($DebugInstall) { Write-Host ">>> $Message" -ForegroundColor Cyan }
}
function Test-Signature {
param([string]$FilePath)
$sig = Get-AuthenticodeSignature -FilePath $FilePath
if ($sig.Status -ne "Valid") {
Write-Status " Signature status: $($sig.Status)"
return $false
}
# Verify it's signed by Ollama Inc. (check exact organization name)
# Anchor with comma/boundary to prevent "O=Not Ollama Inc." from matching
$subject = $sig.SignerCertificate.Subject
if ($subject -notmatch "(^|, )O=Ollama Inc\.(,|$)") {
Write-Status " Unexpected signer: $subject"
return $false
}
Write-Status " Signature valid: $subject"
return $true
}
function Find-InnoSetupInstall {
# Check both HKCU (per-user) and HKLM (per-machine) locations
$possibleKeys = @(
"HKCU:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
"HKLM:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
"HKLM:\Software\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid"
)
foreach ($key in $possibleKeys) {
if (Test-Path $key) {
Write-Status " Found install at: $key"
return $key
}
}
return $null
}
function Update-SessionPath {
# Update PATH in current session so 'ollama' works immediately
if ($InstallDir) {
$ollamaDir = $InstallDir
} else {
$ollamaDir = Join-Path $env:LOCALAPPDATA "Programs\Ollama"
}
# Add to PATH if not already present
if (Test-Path $ollamaDir) {
$currentPath = $env:PATH -split ';'
if ($ollamaDir -notin $currentPath) {
$env:PATH = "$ollamaDir;$env:PATH"
Write-Status " Added $ollamaDir to session PATH"
}
}
}
function Invoke-Download {
param(
[string]$Url,
[string]$OutFile
)
Write-Status " Downloading: $Url"
try {
Invoke-WebRequest -Uri $Url -OutFile $OutFile -UseBasicParsing
$size = (Get-Item $OutFile).Length
Write-Status " Downloaded: $([math]::Round($size / 1MB, 1)) MB"
} catch {
if ($_.Exception.Response.StatusCode -eq 404) {
throw "Download failed: not found at $Url"
}
throw "Download failed for ${Url}: $($_.Exception.Message)"
}
}
# --------------------------------------------------------------------------
# Uninstall
# --------------------------------------------------------------------------
function Invoke-Uninstall {
Write-Step "Uninstalling Ollama"
$regKey = Find-InnoSetupInstall
if (-not $regKey) {
Write-Host "Ollama is not installed."
return
}
$uninstallString = (Get-ItemProperty -Path $regKey).UninstallString
if (-not $uninstallString) {
Write-Warning "No uninstall string found in registry"
return
}
# Strip quotes if present
$uninstallExe = $uninstallString -replace '"', ''
Write-Status " Uninstaller: $uninstallExe"
if (-not (Test-Path $uninstallExe)) {
Write-Warning "Uninstaller not found at: $uninstallExe"
return
}
Write-Host "Launching uninstaller..."
# Run with GUI so user can choose whether to keep models
Start-Process -FilePath $uninstallExe -Wait
# Verify removal
if (Find-InnoSetupInstall) {
Write-Warning "Uninstall may not have completed"
} else {
Write-Host "Ollama has been uninstalled."
}
}
# --------------------------------------------------------------------------
# Install
# --------------------------------------------------------------------------
function Invoke-Install {
# Determine installer URL
if ($Version) {
$installerUrl = "$DownloadBaseURL/OllamaSetup.exe?version=$Version"
} else {
$installerUrl = "$DownloadBaseURL/OllamaSetup.exe"
}
# Download installer
Write-Step "Downloading Ollama"
if (-not $DebugInstall) {
Write-Host "Downloading Ollama..."
}
$tempInstaller = Join-Path $env:TEMP "OllamaSetup.exe"
Invoke-Download -Url $installerUrl -OutFile $tempInstaller
# Verify signature
Write-Step "Verifying signature"
if (-not (Test-Signature -FilePath $tempInstaller)) {
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
throw "Installer signature verification failed"
}
# Build installer arguments
$installerArgs = "/VERYSILENT /NORESTART /SUPPRESSMSGBOXES"
if ($InstallDir) {
$installerArgs += " /DIR=`"$InstallDir`""
}
Write-Status " Installer args: $installerArgs"
# Run installer
Write-Step "Installing Ollama"
if (-not $DebugInstall) {
Write-Host "Installing..."
}
# Create upgrade marker so the app starts hidden
# The app checks for this file on startup and removes it after
$markerDir = Join-Path $env:LOCALAPPDATA "Ollama"
$markerFile = Join-Path $markerDir "upgraded"
if (-not (Test-Path $markerDir)) {
New-Item -ItemType Directory -Path $markerDir -Force | Out-Null
}
New-Item -ItemType File -Path $markerFile -Force | Out-Null
Write-Status " Created upgrade marker: $markerFile"
# Start installer and wait for just the installer process (not children)
# Using -Wait would wait for Ollama to exit too, which we don't want
$proc = Start-Process -FilePath $tempInstaller `
-ArgumentList $installerArgs `
-PassThru
$proc.WaitForExit()
if ($proc.ExitCode -ne 0) {
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
throw "Installation failed with exit code $($proc.ExitCode)"
}
# Cleanup
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
# Update PATH in current session so 'ollama' works immediately
Write-Step "Updating session PATH"
Update-SessionPath
Write-Host "Install complete. You can now run 'ollama'."
}
# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
if ($Uninstall) {
Invoke-Uninstall
} else {
Invoke-Install
}

View File

@@ -1,5 +1,5 @@
#!/bin/sh
# This script installs Ollama on Linux.
# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
@@ -27,8 +27,7 @@ require() {
echo $MISSING
}
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
OS="$(uname -s)"
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
@@ -36,6 +35,63 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
###########################################
# macOS
###########################################
if [ "$OS" = "Darwin" ]; then
NEEDS=$(require curl unzip)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
DOWNLOAD_URL="https://ollama.com/download/Ollama-darwin.zip${VER_PARAM}"
if pgrep -x Ollama >/dev/null 2>&1; then
status "Stopping running Ollama instance..."
pkill -x Ollama 2>/dev/null || true
sleep 2
fi
if [ -d "/Applications/Ollama.app" ]; then
status "Removing existing Ollama installation..."
rm -rf "/Applications/Ollama.app"
fi
status "Downloading Ollama for macOS..."
curl --fail --show-error --location --progress-bar \
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
status "Installing Ollama to /Applications..."
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
mv "$TEMP_DIR/Ollama.app" "/Applications/"
status "Adding 'ollama' command to PATH (may require password)..."
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
if [ -z "${OLLAMA_NO_START:-}" ]; then
status "Starting Ollama..."
open -a Ollama --args hidden
fi
status "Install complete. You can now run 'ollama'."
exit 0
fi
###########################################
# Linux
###########################################
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
IS_WSL2=false
KERN=$(uname -r)
@@ -45,8 +101,6 @@ case "$KERN" in
*) ;;
esac
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
SUDO=
if [ "$(id -u)" -ne 0 ]; then
# Not running as root, so sudo is needed for privileged steps

422
server/aliases.go Normal file
View File

@@ -0,0 +1,422 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const (
serverConfigFilename = "server.json"
serverConfigVersion = 1
)
var errAliasCycle = errors.New("alias cycle detected")
type aliasEntry struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
type serverConfig struct {
Version int `json:"version"`
Aliases []aliasEntry `json:"aliases"`
}
type store struct {
mu sync.RWMutex
path string
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
prefixEntries []aliasEntry // prefix matches, sorted longest-first
}
func createStore(path string) (*store, error) {
store := &store{
path: path,
entries: make(map[string]aliasEntry),
}
if err := store.load(); err != nil {
return nil, err
}
return store, nil
}
func (s *store) load() error {
data, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var cfg serverConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
return fmt.Errorf("unsupported router config version %d", cfg.Version)
}
for _, entry := range cfg.Aliases {
targetName := model.ParseName(entry.Target)
if !targetName.IsValid() {
slog.Warn("invalid alias target in router config", "target", entry.Target)
continue
}
canonicalTarget := displayAliasName(targetName)
if entry.PrefixMatching {
// Prefix aliases don't need to be valid model names
alias := strings.TrimSpace(entry.Alias)
if alias == "" {
slog.Warn("empty prefix alias in router config")
continue
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: alias,
Target: canonicalTarget,
PrefixMatching: true,
})
} else {
aliasName := model.ParseName(entry.Alias)
if !aliasName.IsValid() {
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
continue
}
canonicalAlias := displayAliasName(aliasName)
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
Alias: canonicalAlias,
Target: canonicalTarget,
}
}
}
// Sort prefix entries by alias length descending (longest prefix wins)
s.sortPrefixEntriesLocked()
return nil
}
func (s *store) saveLocked() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
// Combine exact and prefix entries
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
cfg := serverConfig{
Version: serverConfigVersion,
Aliases: entries,
}
f, err := os.CreateTemp(dir, "router-*.json")
if err != nil {
return err
}
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(cfg); err != nil {
_ = f.Close()
_ = os.Remove(f.Name())
return err
}
if err := f.Close(); err != nil {
_ = os.Remove(f.Name())
return err
}
if err := os.Chmod(f.Name(), 0o644); err != nil {
_ = os.Remove(f.Name())
return err
}
return os.Rename(f.Name(), s.path)
}
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
// If a local model exists, do not allow alias shadowing (highest priority).
exists, err := localModelExists(name)
if err != nil {
return name, false, err
}
if exists {
return name, false, nil
}
key := normalizeAliasKey(name)
s.mu.RLock()
entry, exactMatch := s.entries[key]
var prefixMatch *aliasEntry
if !exactMatch {
// Try prefix matching - prefixEntries is sorted longest-first
nameStr := strings.ToLower(displayAliasName(name))
for i := range s.prefixEntries {
prefix := strings.ToLower(s.prefixEntries[i].Alias)
if strings.HasPrefix(nameStr, prefix) {
prefixMatch = &s.prefixEntries[i]
break // First match is longest due to sorting
}
}
}
s.mu.RUnlock()
if !exactMatch && prefixMatch == nil {
return name, false, nil
}
var current string
var visited map[string]struct{}
if exactMatch {
visited = map[string]struct{}{key: {}}
current = entry.Target
} else {
// For prefix match, use the target as-is
visited = map[string]struct{}{}
current = prefixMatch.Target
}
targetKey := normalizeAliasKeyString(current)
for {
targetName := model.ParseName(current)
if !targetName.IsValid() {
return name, false, fmt.Errorf("alias target %q is invalid", current)
}
if _, seen := visited[targetKey]; seen {
return name, false, errAliasCycle
}
visited[targetKey] = struct{}{}
s.mu.RLock()
next, ok := s.entries[targetKey]
s.mu.RUnlock()
if !ok {
return targetName, true, nil
}
current = next.Target
targetKey = normalizeAliasKeyString(current)
}
}
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
targetKey := normalizeAliasKey(target)
s.mu.Lock()
defer s.mu.Unlock()
if prefixMatching {
// For prefix aliases, we skip cycle detection since prefix matching
// works differently and the target is a specific model
aliasStr := displayAliasName(alias)
// Remove any existing prefix entry with the same alias
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
break
}
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: aliasStr,
Target: displayAliasName(target),
PrefixMatching: true,
})
s.sortPrefixEntriesLocked()
return s.saveLocked()
}
aliasKey := normalizeAliasKey(alias)
if aliasKey == targetKey {
return fmt.Errorf("alias cannot point to itself")
}
visited := map[string]struct{}{aliasKey: {}}
currentKey := targetKey
for {
if _, seen := visited[currentKey]; seen {
return errAliasCycle
}
visited[currentKey] = struct{}{}
next, ok := s.entries[currentKey]
if !ok {
break
}
currentKey = normalizeAliasKeyString(next.Target)
}
s.entries[aliasKey] = aliasEntry{
Alias: displayAliasName(alias),
Target: displayAliasName(target),
}
return s.saveLocked()
}
func (s *store) Delete(alias model.Name) (bool, error) {
aliasKey := normalizeAliasKey(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try exact match first
if _, ok := s.entries[aliasKey]; ok {
delete(s.entries, aliasKey)
return true, s.saveLocked()
}
// Try prefix entries
aliasStr := displayAliasName(alias)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
return false, nil
}
// DeleteByString deletes an alias by its raw string value, useful for prefix
// aliases that may not be valid model names.
func (s *store) DeleteByString(alias string) (bool, error) {
alias = strings.TrimSpace(alias)
aliasLower := strings.ToLower(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try prefix entries first (since this is mainly for prefix aliases)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, alias) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
// Also check exact entries by normalized key
if _, ok := s.entries[aliasLower]; ok {
delete(s.entries, aliasLower)
return true, s.saveLocked()
}
return false, nil
}
func (s *store) List() []aliasEntry {
s.mu.RLock()
defer s.mu.RUnlock()
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
return entries
}
func normalizeAliasKey(name model.Name) string {
return strings.ToLower(displayAliasName(name))
}
func (s *store) sortPrefixEntriesLocked() {
sort.Slice(s.prefixEntries, func(i, j int) bool {
// Sort by length descending (longest prefix first)
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
})
}
func normalizeAliasKeyString(value string) string {
n := model.ParseName(value)
if !n.IsValid() {
return strings.ToLower(strings.TrimSpace(value))
}
return normalizeAliasKey(n)
}
func displayAliasName(n model.Name) string {
display := n.DisplayShortest()
if strings.EqualFold(n.Tag, "latest") {
if idx := strings.LastIndex(display, ":"); idx != -1 {
return display[:idx]
}
}
return display
}
func localModelExists(name model.Name) (bool, error) {
manifests, err := manifest.Manifests(true)
if err != nil {
return false, err
}
needle := name.String()
for existing := range manifests {
if strings.EqualFold(existing.String(), needle) {
return true, nil
}
}
return false, nil
}
func serverConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".ollama", serverConfigFilename)
}
return filepath.Join(home, ".ollama", serverConfigFilename)
}
func (s *Server) aliasStore() (*store, error) {
s.aliasesOnce.Do(func() {
s.aliases, s.aliasesErr = createStore(serverConfigPath())
})
return s.aliases, s.aliasesErr
}
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
store, err := s.aliasStore()
if err != nil {
return name, false, err
}
if store == nil {
return name, false, nil
}
return store.ResolveName(name)
}
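
Taken together, ResolveName applies a fixed precedence: an existing local model always shadows aliases, an exact alias beats any prefix alias, the longest matching prefix wins among prefix aliases, and targets are then chased through further exact aliases until a non-alias name or a cycle is hit. A minimal in-package sketch of that behavior, mirroring the store-level tests added later in this change (model names are illustrative):

package server

import (
	"path/filepath"
	"testing"

	"github.com/ollama/ollama/types/model"
)

// Sketch: exercises the store directly. createStore and aliasEntry are
// unexported, so this lives in package server like the tests below.
func TestAliasResolutionSketch(t *testing.T) {
	s, err := createStore(filepath.Join(t.TempDir(), "server.json"))
	if err != nil {
		t.Fatal(err)
	}
	// Exact alias: "work" -> "hub/llama".
	if err := s.Set(model.ParseName("work"), model.ParseName("hub/llama"), false); err != nil {
		t.Fatal(err)
	}
	// Prefix alias "gpt-" -> "work", added the same way the tests below add it.
	s.mu.Lock()
	s.prefixEntries = append(s.prefixEntries, aliasEntry{Alias: "gpt-", Target: "work", PrefixMatching: true})
	s.sortPrefixEntriesLocked()
	s.mu.Unlock()
	// "gpt-5" has no local model and no exact alias, so it matches the
	// prefix and then follows the exact chain: gpt-5 -> work -> hub/llama.
	resolved, ok, err := s.ResolveName(model.ParseName("gpt-5"))
	if err != nil || !ok {
		t.Fatalf("expected resolution, got ok=%v err=%v", ok, err)
	}
	if want := model.ParseName("hub/llama"); resolved.DisplayShortest() != want.DisplayShortest() {
		t.Fatalf("expected %q, got %q", want.DisplayShortest(), resolved.DisplayShortest())
	}
}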

View File

@@ -22,6 +22,7 @@ import (
"os/signal"
"slices"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -51,7 +52,7 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
xserver "github.com/ollama/ollama/x/server"
)
@@ -81,6 +82,9 @@ type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
aliasesOnce sync.Once
aliases *store
aliasesErr error
}
func init() {
@@ -191,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
// We cannot currently consolidate this into GetModel because doing so
// would induce infinite recursion given the current code structure.
name, err := getExistingName(name)
name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -1095,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
@@ -1580,6 +1591,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
@@ -1950,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
name, err := getExistingName(name)
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(req.Model)
m, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):

159
server/routes_aliases.go Normal file
View File

@@ -0,0 +1,159 @@
package server
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/types/model"
)
type aliasListResponse struct {
Aliases []aliasEntry `json:"aliases"`
}
type aliasDeleteRequest struct {
Alias string `json:"alias"`
}
func (s *Server) ListAliasesHandler(c *gin.Context) {
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var aliases []aliasEntry
if store != nil {
aliases = store.List()
}
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
}
func (s *Server) CreateAliasHandler(c *gin.Context) {
var req aliasEntry
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
req.Target = strings.TrimSpace(req.Target)
if req.Alias == "" || req.Target == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
return
}
// Target must always be a valid model name
targetName := model.ParseName(req.Target)
if !targetName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
return
}
var aliasName model.Name
if req.PrefixMatching {
// For prefix aliases, we still parse the alias to normalize it,
// but we allow any non-empty string since prefix patterns may not be valid model names
aliasName = model.ParseName(req.Alias)
// Even if not valid as a model name, we accept it for prefix matching
} else {
aliasName = model.ParseName(req.Alias)
if !aliasName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
return
}
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
return
}
exists, err := localModelExists(aliasName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if exists {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
return
}
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, errAliasCycle) {
status = http.StatusBadRequest
}
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
return
}
resp := aliasEntry{
Alias: displayAliasName(aliasName),
Target: displayAliasName(targetName),
PrefixMatching: req.PrefixMatching,
}
if req.PrefixMatching && !aliasName.IsValid() {
// For prefix aliases that aren't valid model names, use the raw alias
resp.Alias = req.Alias
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) DeleteAliasHandler(c *gin.Context) {
var req aliasDeleteRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
if req.Alias == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
return
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
aliasName := model.ParseName(req.Alias)
var deleted bool
if aliasName.IsValid() {
deleted, err = store.Delete(aliasName)
} else {
// For invalid model names (like prefix aliases), try deleting by raw string
deleted, err = store.DeleteByString(req.Alias)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !deleted {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
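
The handlers above pair with the experimental routes registered earlier (GET/POST/DELETE on /api/experimental/aliases). A sketch of driving them from a plain HTTP client — the base URL assumes the default local listen address, and the field names follow aliasEntry's JSON tags:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:11434/api/experimental/aliases"
	// Create a prefix alias: POST an aliasEntry-shaped body.
	body, _ := json.Marshal(map[string]any{
		"alias":           "gpt-",
		"target":          "llama3.2",
		"prefix_matching": true,
	})
	resp, err := http.Post(base, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	// List aliases: the response wraps entries in an "aliases" array.
	resp, err = http.Get(base)
	if err != nil {
		panic(err)
	}
	var list struct {
		Aliases []struct {
			Alias          string `json:"alias"`
			Target         string `json:"target"`
			PrefixMatching bool   `json:"prefix_matching,omitempty"`
		} `json:"aliases"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Printf("%+v\n", list.Aliases)
	// Delete by alias: a DELETE with a JSON body needs an explicit request.
	del, _ := json.Marshal(map[string]string{"alias": "gpt-"})
	req, _ := http.NewRequest(http.MethodDelete, base, bytes.NewReader(del))
	req.Header.Set("Content-Type", "application/json")
	if resp, err = http.DefaultClient.Do(req); err != nil {
		panic(err)
	}
	resp.Body.Close()
}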

View File

@@ -0,0 +1,426 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestAliasShadowingRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "shadowed-model",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
}
func TestAliasResolvesForChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
var remoteModel string
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req api.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatal(err)
}
remoteModel = req.Model
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: req.Model,
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()
p, err := url.Parse(rs.URL)
if err != nil {
t.Fatal(err)
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "target-model",
RemoteHost: rs.URL,
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "alias-model",
Messages: []api.Message{{Role: "user", Content: "hi"}},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Model != "alias-model" {
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
}
if remoteModel != "test" {
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
}
}
func TestPrefixAliasBasicMatching(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a prefix alias: "myprefix-" -> "targetmodel"
targetName := model.ParseName("targetmodel")
// Set a prefix alias (using "myprefix-" as the pattern)
store.mu.Lock()
store.prefixEntries = append(store.prefixEntries, aliasEntry{
Alias: "myprefix-",
Target: "targetmodel",
PrefixMatching: true,
})
store.mu.Unlock()
// Test that "myprefix-foo" resolves to "targetmodel"
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != targetName.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
}
// Test that "otherprefix-foo" does not resolve
otherName := model.ParseName("otherprefix-foo")
_, wasResolved, err = store.ResolveName(otherName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatal("expected name not to be resolved")
}
// Test that exact alias takes precedence
exactAlias := model.ParseName("myprefix-exact")
exactTarget := model.ParseName("exacttarget")
if err := store.Set(exactAlias, exactTarget, false); err != nil {
t.Fatal(err)
}
resolved, wasResolved, err = store.ResolveName(exactAlias)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasLongestMatchWins(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add two prefix aliases with overlapping patterns
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
}
store.sortPrefixEntriesLocked()
store.mu.Unlock()
// "abc-def-ghi" should match the longer prefix "abc-def-"
testName := model.ParseName("abc-def-ghi")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedLongTarget := model.ParseName("long-target")
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
}
// "abc-xyz" should match the shorter prefix "abc-"
testName2 := model.ParseName("abc-xyz")
resolved, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedShortTarget := model.ParseName("short-target")
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasChain(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a chain: prefix "test-" -> "intermediate" -> "final"
intermediate := model.ParseName("intermediate")
final := model.ParseName("final")
// Add prefix alias
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
}
store.mu.Unlock()
// Add exact alias for the intermediate step
if err := store.Set(intermediate, final, false); err != nil {
t.Fatal(err)
}
// "test-foo" should resolve through the chain to "final"
testName := model.ParseName("test-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != final.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasCRUD(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a prefix alias via API
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "llama2",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var createResp aliasEntry
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
t.Fatal(err)
}
if !createResp.PrefixMatching {
t.Fatal("expected prefix_matching to be true in response")
}
// List aliases and verify the prefix alias is included
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var listResp aliasListResponse
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
found := false
for _, a := range listResp.Aliases {
if a.PrefixMatching && a.Target == "llama2" {
found = true
break
}
}
if !found {
t.Fatal("expected to find prefix alias in list")
}
// Delete the prefix alias
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify it's deleted
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
for _, a := range listResp.Aliases {
if a.PrefixMatching {
t.Fatal("expected prefix alias to be deleted")
}
}
}
func TestPrefixAliasCaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add a prefix alias with mixed case
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
}
store.mu.Unlock()
// Test that matching is case-insensitive
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (case-insensitive)")
}
expectedTarget := model.ParseName("targetmodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
// Test uppercase request
testName2 := model.ParseName("MYPREFIX-BAR")
_, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (uppercase)")
}
}
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a local model that would match a prefix alias
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "myprefix-localmodel",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Create a prefix alias that would match the local model name
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "someothermodel",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
store, err := s.aliasStore()
if err != nil {
t.Fatal(err)
}
localModelName := model.ParseName("myprefix-localmodel")
resolved, wasResolved, err := store.ResolveName(localModelName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
}
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
}
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
nonLocalName := model.ParseName("myprefix-nonexistent")
resolved, wasResolved, err = store.ResolveName(nonLocalName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected non-local model to resolve via prefix alias")
}
expectedTarget := model.ParseName("someothermodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -567,16 +567,16 @@ iGPUScan:
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
mode = imagegen.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
mode = imagegen.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := mlxrunner.NewServer(modelName, mode)
server, err := imagegen.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"cmp"
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
var _ Tokenizer = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
for _, p := range pretokenizer {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
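
Since the pattern argument is variadic, callers can now omit it and get the byte-level default above, or pass explicit patterns as the model constructors in this change do. A small sketch with a toy vocabulary (the entries are illustrative, not a real model's):

package main

import "github.com/ollama/ollama/tokenizer"

func main() {
	vocab := &tokenizer.Vocabulary{
		Values: []string{"hello", " world"}, // toy entries, illustrative only
		Types:  []int32{1, 1},               // TOKEN_TYPE_NORMAL
		Merges: []string{},
	}
	// No pattern: the constructor falls back to the byte-level default.
	bpe := tokenizer.NewBytePairEncoding(vocab)
	_ = bpe
	// Explicit patterns are applied in order, e.g. the GPT-style split
	// several model constructors in this change pass in.
	bpe = tokenizer.NewBytePairEncoding(vocab,
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
	)
	_ = bpe
}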

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
var _ Tokenizer = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// For a tokenizer that uses byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
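
The byte-token case this comment describes can be shown in isolation: vocabulary entries like "<0xEA>" stand for single raw bytes, and several of them may need to be buffered before they form a complete UTF-8 character. A hypothetical helper (not the actual implementation) illustrating the mapping:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseByteToken reports whether tok is a byte token like "<0xEA>" and,
// if so, returns the raw byte it encodes.
func parseByteToken(tok string) (byte, bool) {
	if len(tok) != 6 || !strings.HasPrefix(tok, "<0x") || !strings.HasSuffix(tok, ">") {
		return 0, false
	}
	b, err := strconv.ParseUint(tok[3:5], 16, 8)
	if err != nil {
		return 0, false
	}
	return byte(b), true
}

func main() {
	// "é" is 0xC3 0xA9 in UTF-8: two byte tokens reassemble into one rune.
	var buf []byte
	for _, tok := range []string{"<0xC3>", "<0xA9>"} {
		if b, ok := parseByteToken(tok); ok {
			buf = append(buf, b)
		}
	}
	fmt.Println(string(buf)) // é
}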

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
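
The var _ Tokenizer = (*T)(nil) lines scattered through this diff are Go's zero-cost compile-time interface check. A self-contained sketch of the idiom, with the interface redeclared locally; the real one lives in the tokenizer package and includes further methods such as Vocabulary:

package main

import "fmt"

type Special int

type Tokenizer interface {
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode([]int32) (string, error)
	Is(int32, Special) bool
}

// noopTokenizer is a stand-in implementation.
type noopTokenizer struct{}

func (noopTokenizer) Encode(s string, addSpecial bool) ([]int32, error) { return []int32{}, nil }
func (noopTokenizer) Decode(ids []int32) (string, error)               { return "", nil }
func (noopTokenizer) Is(id int32, special Special) bool                { return false }

// Compile-time check, the same idiom the diff applies to BytePairEncoding,
// SentencePiece, and WordPiece: the build fails if the type drifts from the
// interface, and (*T)(nil) costs nothing at runtime.
var _ Tokenizer = (*noopTokenizer)(nil)

func main() {
	var t Tokenizer = noopTokenizer{}
	ids, _ := t.Encode("hi", false)
	fmt.Println(len(ids))
}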

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"testing"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements TextProcessor.
// Decode implements Tokenizer.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements TextProcessor.
// Encode implements Tokenizer.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements TextProcessor.
// Is implements Tokenizer.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
// Vocabulary implements Tokenizer.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"slices"

View File

@@ -19,6 +19,7 @@ import (
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
@@ -35,7 +36,7 @@ type ModelfileConfig struct {
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
@@ -94,6 +95,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
newPackedTensorLayerCreator(),
)
} else {
err = create.CreateImageGenModel(
@@ -141,60 +143,33 @@ func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
// createQuantizedLayers quantizes a tensor and returns a single combined layer.
// The combined blob contains data, scale, and optional bias tensors with metadata.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
// Quantize the tensor into a single combined blob
blobData, err := quantizeTensor(r, name, dtype, shape, quantize)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
// Create single layer for the combined blob
layer, err := manifest.NewLayer(bytes.NewReader(blobData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
return []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
@@ -214,6 +189,58 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}, nil
}
// newPackedTensorLayerCreator returns a PackedTensorLayerCreator callback for
// creating packed multi-tensor blob layers (used for expert groups).
func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
return func(groupName string, tensors []create.PackedTensorInput) (create.LayerInfo, error) {
// Check if any tensor in the group needs quantization
hasQuantize := false
for _, t := range tensors {
if t.Quantize != "" {
hasQuantize = true
break
}
}
var blobReader io.Reader
if hasQuantize {
if !QuantizeSupported() {
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
}
blobData, err := quantizePackedGroup(tensors)
if err != nil {
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
}
blobReader = bytes.NewReader(blobData)
} else {
// Build unquantized packed blob using streaming reader
// Extract raw tensor data from safetensors-wrapped readers
var tds []*safetensors.TensorData
for _, t := range tensors {
rawData, err := safetensors.ExtractRawFromSafetensors(t.Reader)
if err != nil {
return create.LayerInfo{}, fmt.Errorf("failed to extract tensor %s: %w", t.Name, err)
}
td := safetensors.NewTensorDataFromBytes(t.Name, t.Dtype, t.Shape, rawData)
tds = append(tds, td)
}
blobReader = safetensors.BuildPackedSafetensorsReader(tds)
}
layer, err := manifest.NewLayer(blobReader, manifest.MediaTypeImageTensor)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: groupName,
}, nil
}
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {

View File

@@ -3,128 +3,195 @@
package client
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types:
// - "q4": affine 4-bit, group_size=32 (with qbiases)
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
// - "q8": affine 8-bit, group_size=64 (with qbiases)
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
// quantizeParams maps quantization type names to MLX quantize parameters.
var quantizeParams = map[string]struct {
groupSize int
bits int
mode string
}{
"int4": {32, 4, "affine"},
"nvfp4": {16, 4, "nvfp4"},
"int8": {64, 8, "affine"},
"mxfp8": {32, 8, "mxfp8"},
}
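
For a feel of what these parameters imply, a back-of-the-envelope sketch. It assumes MLX packs quantized weights into uint32 words and emits one scale per group along the last dimension, plus one bias per group in affine mode; the scale widths are assumptions, not measurements:

package main

import "fmt"

func main() {
	const rows, cols = 4096, 4096 // a typical square projection weight

	for _, q := range []struct {
		name       string
		groupSize  int
		bits       int
		affine     bool // affine modes also carry a per-group bias
		scaleBytes int  // assumed: 2 for bf16 scales, 1 for the E4M3 scale modes
	}{
		{"int4", 32, 4, true, 2},
		{"nvfp4", 16, 4, false, 1},
		{"int8", 64, 8, true, 2},
		{"mxfp8", 32, 8, false, 1},
	} {
		packedWords := cols * q.bits / 32 // 4-bit packs 8 values per uint32, 8-bit packs 4
		groups := cols / q.groupSize      // one scale (and bias, if affine) per group
		total := rows * (packedWords*4 + groups*q.scaleBytes)
		if q.affine {
			total += rows * groups * q.scaleBytes
		}
		fmt.Printf("%-6s ~%d MiB (bf16 would be %d MiB)\n", q.name, total>>20, rows*cols*2>>20)
	}
}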
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and adds the resulting arrays (weight, scale, optional bias)
// to the provided map. If quantize is empty, the tensor is kept as-is.
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
tmpPath = tmpFile.Name()
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
// Find the tensor key (may differ from name for single-tensor blobs)
inputKey, err := findSafetensorsKey(tmpPath)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
}
arr := st.Get(inputKey)
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
st.Free()
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
}
// Convert to BFloat16 if needed (quantize expects float type)
if quantize == "" {
arr = mlx.Contiguous(arr)
arrays[name] = arr
return tmpPath, []*mlx.Array{arr}, st, nil
}
// Convert to float type if needed (quantize expects float)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "q4":
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "nvfp4":
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
case "q8":
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
case "mxfp8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
params, ok := quantizeParams[quantize]
if !ok {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
// Eval and make contiguous for data access
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
arrays[name] = qweight
arrays[name+".scale"] = scales
toEval = append(toEval, qweight, scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
arrays[name+".bias"] = qbiases
toEval = append(toEval, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
return tmpPath, toEval, st, nil
}
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size.
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
arrays := make(map[string]*mlx.Array)
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
if tmpPath != "" {
defer os.Remove(tmpPath)
}
if st != nil {
defer st.Free()
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
return nil, err
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
mlx.Eval(toEval...)
// Build metadata for single-tensor blobs
params := quantizeParams[quantize]
metadata := map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(params.groupSize),
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil {
return nil, fmt.Errorf("failed to save combined blob: %w", err)
}
return os.ReadFile(outPath)
}
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
// combined safetensors blob. Used for packing expert groups.
// Each tensor may have a different quantization type (mixed-precision).
// Returns the blob bytes. No __metadata__ is added because different tensors
// may use different quantization types.
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
allArrays := make(map[string]*mlx.Array)
var allToEval []*mlx.Array
var tmpPaths []string
var handles []*mlx.SafetensorsFile
for _, input := range inputs {
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
if tmpPath != "" {
tmpPaths = append(tmpPaths, tmpPath)
}
if st != nil {
handles = append(handles, st)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
// Cleanup on error
for _, h := range handles {
h.Free()
}
for _, p := range tmpPaths {
os.Remove(p)
}
return nil, err
}
allToEval = append(allToEval, toEval...)
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
mlx.Eval(allToEval...)
// Free native handles after eval
for _, h := range handles {
h.Free()
}
// Save combined blob (no global metadata for mixed-precision packed blobs)
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
return nil, fmt.Errorf("failed to save packed blob: %w", err)
}
blobData, err := os.ReadFile(outPath)
if err != nil {
return nil, fmt.Errorf("failed to read packed blob: %w", err)
}
for _, p := range tmpPaths {
os.Remove(p)
}
return blobData, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
@@ -138,3 +205,33 @@ func ensureTempDir() string {
os.MkdirAll(tmpDir, 0755)
return tmpDir
}
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
func findSafetensorsKey(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return "", err
}
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(f, headerBytes); err != nil {
return "", err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return "", err
}
for k := range header {
if k != "__metadata__" {
return k, nil
}
}
return "", fmt.Errorf("no tensor found in safetensors header")
}
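
For reference, the layout findSafetensorsKey parses is small enough to write by hand: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/data_offsets, then the raw bytes. A hedged sketch that produces a one-tensor file whose first non-metadata key would resolve to "weight" (the file name and key are illustrative; later error checks are elided for brevity):

package main

import (
	"encoding/binary"
	"encoding/json"
	"os"
)

func main() {
	data := make([]byte, 2*3*4) // 2x3 float32, zeroed for the sketch

	header := map[string]any{
		"__metadata__": map[string]string{"quant_type": "none"},
		"weight": map[string]any{
			"dtype":        "F32",
			"shape":        []int{2, 3},
			"data_offsets": []int{0, len(data)}, // relative to the data section
		},
	}
	hdr, err := json.Marshal(header)
	if err != nil {
		panic(err)
	}

	f, err := os.Create("tiny.safetensors")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// 8-byte little-endian header size, then the JSON header, then the data.
	binary.Write(f, binary.LittleEndian, uint64(len(hdr)))
	f.Write(hdr)
	f.Write(data)
}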

View File

@@ -5,11 +5,18 @@ package client
import (
"fmt"
"io"
"github.com/ollama/ollama/x/create"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// quantizePackedGroup is not available without MLX
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
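
This stub file is the !mlx half of Go's build-tag pairing. A minimal sketch of the pattern, shown as two files in one block (the file names are illustrative; QuantizeSupported and the mlx tag come from this diff):

// feature_mlx.go: compiled only with "go build -tags mlx"

//go:build mlx

package client

func QuantizeSupported() bool { return true }

// feature_stub.go: the default build gets the stub instead

//go:build !mlx

package client

func QuantizeSupported() bool { return false }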

View File

@@ -6,7 +6,9 @@ import (
"io"
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strings"
"github.com/ollama/ollama/envconfig"
@@ -228,7 +230,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
// When quantize is non-empty (e.g., "int8"), the quantized weight, scale, and optional bias are returned (now combined into a single blob layer).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -264,19 +266,19 @@ func ShouldQuantize(name, component string) bool {
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != ""
}
// normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4":
return "q4"
return "int4"
case "Q8", "INT8", "FP8":
return "q8"
return "int8"
case "NVFP4":
return "nvfp4"
case "MXFP8":
@@ -286,29 +288,12 @@ func normalizeQuantType(quantize string) string {
}
}
// getQuantGroupSize returns the group size for a given quantization type.
// These must match the values used in quantize.go when creating quantized models.
func getQuantGroupSize(quantize string) int {
switch normalizeQuantType(quantize) {
case "nvfp4":
return 16
case "q4":
return 32
case "mxfp8":
return 32
case "q8":
return 64
default:
return 32
}
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
// - Output projection, gate/up weights: q4 (less sensitive)
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Output projection, gate/up weights: int4 (less sensitive)
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
@@ -330,12 +315,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, q4/mxfp8: 32, q8: 64
// nvfp4: 16, int4/mxfp8: 32, int8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
groupSize = 16
case "q8":
case "int8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
@@ -363,13 +348,13 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return "" // No quantization - keep bf16
}
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
if strings.Contains(name, "down_proj") {
return "q8"
return "int8"
}
// Output projection, gate/up weights - use requested quantization (Q4)
// Output projection, gate/up weights - use requested quantization (INT4)
// o_proj, gate_proj, up_proj
if strings.Contains(name, "o_proj") ||
strings.Contains(name, "gate_proj") ||
@@ -386,14 +371,69 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return quantNorm
}
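
To make the policy above concrete, an illustrative stand-in for GetTensorQuantization. It is greatly simplified: name-based rules only, with none of the basic-name, minimum-size, or shape-divisibility checks the real function applies:

package main

import (
	"fmt"
	"strings"
)

// pick is a simplified stand-in for GetTensorQuantization.
func pick(name, requested string) string {
	switch {
	case strings.Contains(name, "norm"),
		strings.Contains(name, "embed_tokens"),
		strings.Contains(name, "q_a"), strings.Contains(name, "q_b"),
		strings.Contains(name, "kv_a"), strings.Contains(name, "kv_b"):
		return "" // keep bf16: norms, embeddings, MLA attention weights
	case strings.Contains(name, "down_proj"):
		return "int8" // more sensitive; MLX has no Q6 kernel
	case strings.Contains(name, "o_proj"),
		strings.Contains(name, "gate_proj"),
		strings.Contains(name, "up_proj"):
		return requested // the requested type, e.g. int4
	default:
		return ""
	}
}

func main() {
	for _, n := range []string{
		"model.layers.0.mlp.down_proj.weight",
		"model.layers.0.mlp.up_proj.weight",
		"model.layers.0.self_attn.kv_a_proj.weight",
		"model.norm.weight",
	} {
		fmt.Printf("%-45s -> %q\n", n, pick(n, "int4"))
	}
}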
// expertGroupRegexp matches expert tensor names and captures the group prefix.
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
// Captures: model.layers.{L}.mlp.experts
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example:
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
func ExpertGroupPrefix(tensorName string) string {
m := expertGroupRegexp.FindStringSubmatch(tensorName)
if m == nil {
return ""
}
return m[1]
}
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
type PackedTensorInput struct {
Name string
Dtype string
Shape []int32
Quantize string // per-tensor quantization type (may differ within group)
Reader io.Reader // safetensors-wrapped tensor data
}
// PackedTensorLayerCreator creates a single blob layer containing multiple packed tensors.
// groupName is the group prefix (e.g., "model.layers.1.mlp.experts").
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Expert tensors are packed into per-layer blobs when createPackedLayer is non-nil.
// If quantize is non-empty (e.g., "int8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error {
var layers []LayerInfo
var configLayer LayerInfo
// Resolve the optional packed layer creator
var packedCreator PackedTensorLayerCreator
if len(createPackedLayer) > 0 {
packedCreator = createPackedLayer[0]
}
// Accumulate expert tensors by group prefix for packing.
// Readers reference file-backed SectionReaders, so we keep extractors
// open until each group is flushed to avoid buffering tensor data in memory.
expertGroups := make(map[string][]PackedTensorInput)
var expertGroupOrder []string
// Track open extractors so we can close them after flushing groups
var openExtractors []*safetensors.TensorExtractor
closeExtractors := func() {
for _, ext := range openExtractors {
ext.Close()
}
openExtractors = nil
}
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
@@ -410,6 +450,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
closeExtractors()
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
@@ -420,10 +461,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
// Track whether this extractor has expert tensors that need to stay open
hasExpertTensors := false
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
@@ -434,20 +479,65 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
// Check if this tensor belongs to an expert group for packing
groupPrefix := ""
if packedCreator != nil {
groupPrefix = ExpertGroupPrefix(tensorName)
}
if groupPrefix != "" {
// Accumulate expert tensor for packed blob.
// The Reader uses a file-backed SectionReader, so we must
// keep the extractor open until this group is flushed.
hasExpertTensors = true
if _, exists := expertGroups[groupPrefix]; !exists {
expertGroupOrder = append(expertGroupOrder, groupPrefix)
}
expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{
Name: tensorName,
Dtype: td.Dtype,
Shape: td.Shape,
Quantize: quantizeType,
Reader: td.SafetensorsReader(),
})
} else {
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
layers = append(layers, newLayers...)
}
extractor.Close()
if hasExpertTensors {
// Keep extractor open - readers still reference its file handle
openExtractors = append(openExtractors, extractor)
} else {
extractor.Close()
}
}
// Process accumulated expert groups into packed blobs, then close extractors
if packedCreator != nil {
sort.Strings(expertGroupOrder)
for _, groupName := range expertGroupOrder {
tensors := expertGroups[groupName]
fn(fmt.Sprintf("packing %s (%d tensors)", groupName, len(tensors)))
layer, err := packedCreator(groupName, tensors)
if err != nil {
closeExtractors()
return fmt.Errorf("failed to create packed layer for %s: %w", groupName, err)
}
layers = append(layers, layer)
}
}
closeExtractors()
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
@@ -487,23 +577,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("config.json not found in %s", modelDir)
}
// Create model_index.json with quantization info if quantizing
if quantize != "" {
modelIndex := map[string]any{
"quantization": strings.ToUpper(quantize),
"group_size": getQuantGroupSize(quantize),
}
indexData, err := json.MarshalIndent(modelIndex, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal model_index.json: %w", err)
}
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
if err != nil {
return fmt.Errorf("failed to create model_index.json layer: %w", err)
}
layers = append(layers, indexLayer)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
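
A condensed sketch of the accumulate-then-flush flow above, reusing the same capture as expertGroupRegexp; the real code additionally keeps the file-backed extractors open until each group is flushed:

package main

import (
	"fmt"
	"regexp"
	"sort"
)

// Same capture as expertGroupRegexp in the diff above.
var re = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)

func main() {
	names := []string{
		"model.layers.1.mlp.experts.0.down_proj.weight",
		"model.layers.0.mlp.experts.0.up_proj.weight",
		"model.layers.1.mlp.experts.1.down_proj.weight",
		"model.layers.1.mlp.gate.weight", // routing gate: not packed
	}

	// Bucket expert tensors by group prefix.
	groups := make(map[string][]string)
	for _, n := range names {
		if m := re.FindStringSubmatch(n); m != nil {
			groups[m[1]] = append(groups[m[1]], n)
		}
	}

	// Flush in sorted order so blob creation is deterministic across runs.
	keys := make([]string, 0, len(groups))
	for k := range groups {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		fmt.Printf("pack %s (%d tensors)\n", k, len(groups[k]))
	}
}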

View File

@@ -586,6 +586,39 @@ func TestShouldQuantizeTensor(t *testing.T) {
}
}
func TestExpertGroupPrefix(t *testing.T) {
tests := []struct {
name string
want string
}{
// Expert tensors should return the group prefix
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
// Shared expert tensors should return their own group prefix
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
// Non-expert tensors should return empty string
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
{"model.embed_tokens.weight", ""}, // embedding
{"model.layers.0.self_attn.q_proj.weight", ""}, // attention
{"model.norm.weight", ""}, // norm
{"lm_head.weight", ""}, // output head
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ExpertGroupPrefix(tt.name)
if got != tt.want {
t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tt.name, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
@@ -751,7 +784,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}

View File

@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "q4", "q8", "nvfp4", "mxfp8":
case "", "int4", "int8", "nvfp4", "mxfp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -214,7 +214,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size.
// nvfp4: 16, q4/mxfp8: 32, q8: 64
// nvfp4: 16, int4/mxfp8: 32, int8: 64
func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
@@ -223,7 +223,7 @@ func canQuantizeShape(shape []int32, quantize string) bool {
switch strings.ToUpper(quantize) {
case "NVFP4":
groupSize = 16
case "Q8":
case "INT8":
groupSize = 64
}
return shape[len(shape)-1]%groupSize == 0
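
A few concrete cases of this rule (values chosen to show each outcome):

canQuantizeShape([]int32{4096, 4096}, "int8") // true:  4096 % 64 == 0
canQuantizeShape([]int32{4096, 4100}, "int8") // false: 4100 % 64 != 0
canQuantizeShape([]int32{4096}, "nvfp4")      // false: needs at least 2 dims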

Some files were not shown because too many files have changed in this diff.