mirror of
https://github.com/ollama/ollama.git
synced 2026-02-05 21:23:43 -05:00
Compare commits
4 Commits
pdevine/la
...
v0.15.5-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f53fe7884 | ||
|
|
7ab4ca0e7f | ||
|
|
e36f389e82 | ||
|
|
c61023f554 |
22
.github/workflows/test-install.yaml
vendored
Normal file
22
.github/workflows/test-install.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: test-install
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'scripts/install.sh'
|
||||
- '.github/workflows/test-install.yaml'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run install script
|
||||
run: sh ./scripts/install.sh
|
||||
env:
|
||||
OLLAMA_NO_START: 1 # do not start app
|
||||
- name: Verify ollama is available
|
||||
run: ollama --version
|
||||
@@ -518,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -779,3 +785,123 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4).
|
||||
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
|
||||
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -678,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -731,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -778,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -842,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -899,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -937,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -969,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
|
||||
if tokens < 1 {
|
||||
t.Errorf("expected at least 1 token, got %d", tokens)
|
||||
}
|
||||
// Sanity check: shouldn't be wildly off
|
||||
if tokens > 10 {
|
||||
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// System prompt adds to count
|
||||
if tokens < 5 {
|
||||
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather for a location",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Tools add significant content
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Thinking content should be counted
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
if tokens != 0 {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
||||
type AliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
|
||||
// Setpgid prevents the server from being killed when the parent process exits.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
|
||||
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: 0x08000000,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
194
cmd/cmd.go
194
cmd/cmd.go
@@ -15,7 +15,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -38,7 +37,6 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -1765,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
return fmt.Errorf("ollama server not responding - %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1806,190 +1804,6 @@ Environment Variables:
|
||||
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
|
||||
}
|
||||
|
||||
// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
|
||||
func ensureServerRunning(ctx context.Context) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if server is already running
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server is already running
|
||||
}
|
||||
|
||||
// Server not running, start it in the background
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find executable: %w", err)
|
||||
}
|
||||
|
||||
serverCmd := exec.CommandContext(ctx, exe, "serve")
|
||||
serverCmd.Env = os.Environ()
|
||||
serverCmd.SysProcAttr = backgroundServerSysProcAttr()
|
||||
if err := serverCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the server to be ready
|
||||
for {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runInteractiveTUI runs the main interactive TUI menu.
|
||||
func runInteractiveTUI(cmd *cobra.Command) {
|
||||
// Ensure the server is running before showing the TUI
|
||||
if err := ensureServerRunning(cmd.Context()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// errSelectionCancelled is returned when user cancels model selection
|
||||
errSelectionCancelled := errors.New("cancelled")
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
|
||||
}
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", errSelectionCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
|
||||
}
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, errSelectionCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
for {
|
||||
result, err := tui.Run()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
runModel := func(modelName string) {
|
||||
_ = config.SetLastModel(modelName)
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, errSelectionCancelled) {
|
||||
return false // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if err := config.LaunchIntegration(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch result.Selection {
|
||||
case tui.SelectionNone:
|
||||
// User quit
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Run last model directly if configured and still exists
|
||||
if modelName := config.LastModel(); modelName != "" && config.ModelExists(cmd.Context(), modelName) {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
// No last model or model no longer exists, show picker
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, errSelectionCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
runModel(modelName)
|
||||
}
|
||||
case tui.SelectionChangeRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
modelName := result.Model
|
||||
if modelName == "" {
|
||||
var err error
|
||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, errSelectionCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if !launchIntegration(result.Integration) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
if result.Model != "" {
|
||||
// Model already selected from modal - save and launch
|
||||
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
||||
if errors.Is(err, errSelectionCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewCLI() *cobra.Command {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
cobra.EnableCommandSorting = false
|
||||
@@ -2012,13 +1826,11 @@ func NewCLI() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
runInteractiveTUI(cmd)
|
||||
cmd.Print(cmd.UsageString())
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
@@ -2232,7 +2044,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
config.LaunchCmd(checkServerHeartbeat),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
@@ -60,3 +65,104 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := selectPrompt("Select model:", items)
|
||||
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,18 +10,15 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
@@ -138,81 +134,21 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return ic.Models[0]
|
||||
}
|
||||
|
||||
// LastModel returns the last model that was run, or empty string if none.
|
||||
func LastModel() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastModel
|
||||
}
|
||||
|
||||
// SetLastModel saves the last model that was run.
|
||||
func SetLastModel(model string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastModel = model
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
|
||||
func LastSelection() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastSelection
|
||||
}
|
||||
|
||||
// SetLastSelection saves the last menu selection ("run" or integration name).
|
||||
func SetLastSelection(selection string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastSelection = selection
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
@@ -227,6 +163,29 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
677
cmd/config/config_cloud_test.go
Normal file
677
cmd/config/config_cloud_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
|
||||
if serverErr == nil {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error
|
||||
if serverErr != nil {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
|
||||
@@ -4,12 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -42,6 +39,15 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
|
||||
// Integrations like Claude and Codex use this to route model requests to local models.
|
||||
type AliasConfigurer interface {
|
||||
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
|
||||
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
|
||||
// SetAliases syncs the configured aliases to the server
|
||||
SetAliases(ctx context.Context, aliases map[string]string) error
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
@@ -55,7 +61,7 @@ var integrations = map[string]Runner{
|
||||
|
||||
// recommendedModels are shown when the user has no models or as suggestions.
|
||||
// Order matters: local models first, then cloud models.
|
||||
var recommendedModels = []ModelItem{
|
||||
var recommendedModels = []selectItem{
|
||||
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
|
||||
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
|
||||
{Name: "glm-4.7:cloud", Description: "Recommended"},
|
||||
@@ -68,256 +74,6 @@ var integrationAliases = map[string]bool{
|
||||
"moltbot": true,
|
||||
}
|
||||
|
||||
// integrationInstallURLs maps integration names to their install script URLs.
|
||||
var integrationInstallURLs = map[string]string{
|
||||
"claude": "https://claude.ai/install.sh",
|
||||
"openclaw": "https://openclaw.ai/install.sh",
|
||||
"droid": "https://app.factory.ai/cli",
|
||||
"opencode": "https://opencode.ai/install",
|
||||
}
|
||||
|
||||
// CanInstallIntegration returns true if we have an install script for this integration.
|
||||
func CanInstallIntegration(name string) bool {
|
||||
_, ok := integrationInstallURLs[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||
func IsIntegrationInstalled(name string) bool {
|
||||
switch name {
|
||||
case "claude":
|
||||
c := &Claude{}
|
||||
_, err := c.findPath()
|
||||
return err == nil
|
||||
case "openclaw":
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
case "codex":
|
||||
_, err := exec.LookPath("codex")
|
||||
return err == nil
|
||||
case "droid":
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
case "opencode":
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
default:
|
||||
return true // Assume installed for unknown integrations
|
||||
}
|
||||
}
|
||||
|
||||
// InstallIntegration downloads and runs the install script for an integration.
|
||||
func InstallIntegration(name string) error {
|
||||
url, ok := integrationInstallURLs[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("no install script available for %s", name)
|
||||
}
|
||||
|
||||
// Download the install script
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download install script: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download install script: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
script, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read install script: %w", err)
|
||||
}
|
||||
|
||||
// Create a temporary file for the script
|
||||
tmpDir := os.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, fmt.Sprintf("install-%s.sh", name))
|
||||
if err := os.WriteFile(scriptPath, script, 0o700); err != nil {
|
||||
return fmt.Errorf("failed to write install script: %w", err)
|
||||
}
|
||||
defer os.Remove(scriptPath)
|
||||
|
||||
// Execute the script with bash
|
||||
cmd := exec.Command("bash", scriptPath)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("install script failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SelectModel lets the user select a model to run.
|
||||
// ModelItem represents a model for selection.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
type SingleSelector func(title string, items []ModelItem) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
|
||||
// SelectModelWithSelector prompts the user to select a model using the provided selector.
|
||||
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
lastModel := LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
preChecked = []string{lastModel}
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
|
||||
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Sort with last model first, then existing models, then recommendations
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
aIsLast := a.Name == lastModel
|
||||
bIsLast := b.Name == lastModel
|
||||
if aIsLast != bIsLast {
|
||||
if aIsLast {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
aExists := existingModels[a.Name]
|
||||
bExists := existingModels[b.Name]
|
||||
if aExists != bExists {
|
||||
if aExists {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
selected, err := selector("Select model to run:", items)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// If the selected model isn't installed, pull it first
|
||||
if !existingModels[selected] {
|
||||
msg := fmt.Sprintf("Download %s?", selected)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return "", err
|
||||
} else if !ok {
|
||||
return "", errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, selected); err != nil {
|
||||
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If it's a cloud model, ensure user is signed in
|
||||
if cloudModels[selected] {
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
|
||||
if err != nil || !yes {
|
||||
return "", fmt.Errorf("%s requires sign in", selected)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return "", ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func SelectModel(ctx context.Context) (string, error) {
|
||||
return SelectModelWithSelector(ctx, defaultSingleSelector)
|
||||
}
|
||||
|
||||
func defaultSingleSelector(title string, items []ModelItem) (string, error) {
|
||||
selectItems := make([]selectItem, len(items))
|
||||
for i, item := range items {
|
||||
selectItems[i] = selectItem(item)
|
||||
}
|
||||
return selectPrompt(title, selectItems)
|
||||
}
|
||||
|
||||
func defaultMultiSelector(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
selectItems := make([]selectItem, len(items))
|
||||
for i, item := range items {
|
||||
selectItems[i] = selectItem(item)
|
||||
}
|
||||
return multiSelectPrompt(title, selectItems, preChecked)
|
||||
}
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
if len(integrations) == 0 {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
@@ -340,8 +96,8 @@ func selectIntegration() (string, error) {
|
||||
return selectPrompt("Select integration:", items)
|
||||
}
|
||||
|
||||
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
||||
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
|
||||
// selectModels lets the user select models for an integration
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
@@ -377,12 +133,16 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
|
||||
var selected []string
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := single(fmt.Sprintf("Select model for %s:", r), items)
|
||||
prompt := fmt.Sprintf("Select model for %s:", r)
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := selectPrompt(prompt, items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -410,78 +170,123 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
}
|
||||
}
|
||||
|
||||
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{
|
||||
Name: m.Name,
|
||||
Remote: m.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// selectModels lets the user select models for an integration using default selectors.
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return selectModelsWithSelectors(ctx, name, current, defaultSingleSelector, defaultMultiSelector)
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
@@ -489,114 +294,42 @@ func runIntegration(name, modelName string, args []string) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// LaunchIntegration launches the named integration using saved config or prompts for setup.
|
||||
func LaunchIntegration(name string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
|
||||
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existing {
|
||||
aliases[k] = v
|
||||
}
|
||||
aliases["primary"] = model
|
||||
|
||||
// Try to use saved config
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0], nil)
|
||||
}
|
||||
|
||||
// No saved config - prompt user to run setup
|
||||
return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
|
||||
}
|
||||
|
||||
// LaunchIntegrationWithModel launches the named integration with the specified model.
|
||||
func LaunchIntegrationWithModel(name, modelName string) error {
|
||||
return runIntegration(name, modelName, nil)
|
||||
}
|
||||
|
||||
// SaveIntegrationModel saves the model for an integration.
|
||||
func SaveIntegrationModel(name, modelName string) error {
|
||||
// Load existing models and prepend the new one
|
||||
var models []string
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
models = existing.Models
|
||||
// Remove the model if it already exists
|
||||
for i, m := range models {
|
||||
if m == modelName {
|
||||
models = append(models[:i], models[i+1:]...)
|
||||
break
|
||||
}
|
||||
if isCloudModel(ctx, client, model) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = model
|
||||
}
|
||||
}
|
||||
// Prepend the new model
|
||||
models = append([]string{modelName}, models...)
|
||||
return saveIntegration(name, models)
|
||||
}
|
||||
|
||||
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
|
||||
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
if err := ac.SetAliases(ctx, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
paths := editor.Paths()
|
||||
if len(paths) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||
for _, p := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
if len(models) == 1 {
|
||||
fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigureIntegration allows the user to select/change the model for an integration.
|
||||
func ConfigureIntegration(ctx context.Context, name string) error {
|
||||
return ConfigureIntegrationWithSelectors(ctx, name, defaultSingleSelector, defaultMultiSelector)
|
||||
return saveAliases(name, aliases)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
// The runTUI callback is called when no arguments are provided (alias for main TUI).
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch the Ollama menu or an integration",
|
||||
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||
|
||||
Without arguments, this is equivalent to running 'ollama' directly.
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
@@ -615,12 +348,6 @@ Examples:
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// No args - run the main TUI (same as 'ollama')
|
||||
if len(args) == 0 && modelFlag == "" && !configFlag {
|
||||
runTUI(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract integration name and args to pass through using -- separator
|
||||
var name string
|
||||
var passArgs []string
|
||||
@@ -661,9 +388,87 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0], passArgs)
|
||||
// Handle AliasConfigurer integrations (claude, codex)
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate --model flag if provided
|
||||
if modelFlag != "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
|
||||
return fmt.Errorf("model %q not found", modelFlag)
|
||||
}
|
||||
}
|
||||
|
||||
var model string
|
||||
var existingAliases map[string]string
|
||||
|
||||
// Load saved config
|
||||
if cfg, err := loadIntegration(name); err == nil {
|
||||
existingAliases = cfg.Aliases
|
||||
if len(cfg.Models) > 0 {
|
||||
model = cfg.Models[0]
|
||||
// AliasConfigurer integrations use single model; sanitize if multiple
|
||||
if len(cfg.Models) > 1 {
|
||||
_ = saveIntegration(name, []string{model})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --model flag overrides saved model
|
||||
if modelFlag != "" {
|
||||
model = modelFlag
|
||||
}
|
||||
|
||||
// Validate saved model still exists
|
||||
if model != "" && modelFlag == "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
model = ""
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
}
|
||||
|
||||
// Sync aliases and save
|
||||
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
if err := saveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
// Launch (unless --config without confirmation)
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
|
||||
// Validate --model flag for non-AliasConfigurer integrations
|
||||
if modelFlag != "" {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
|
||||
return fmt.Errorf("model %q not found", modelFlag)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,13 +544,14 @@ Examples:
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
// the ordered items along with maps of existing and cloud model names.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
@@ -765,7 +571,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := ModelItem{Name: displayName}
|
||||
item := selectItem{Name: displayName}
|
||||
if recommended[displayName] {
|
||||
item.Description = "recommended"
|
||||
}
|
||||
@@ -777,7 +583,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if isCloudModel(rec.Name) {
|
||||
if strings.HasSuffix(rec.Name, ":cloud") {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
@@ -814,7 +620,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
slices.SortStableFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
|
||||
@@ -837,58 +643,16 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
// GetModelItems returns a list of model items including recommendations for the TUI.
|
||||
// It includes all locally available models plus recommended models that aren't installed.
|
||||
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
// isCloudModel checks if a model is a cloud model using the Show API.
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
return false
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
lastModel := LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
preChecked = []string{lastModel}
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
|
||||
|
||||
// Sort with last model first, then existing models, then recommendations
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
aIsLast := a.Name == lastModel
|
||||
bIsLast := b.Name == lastModel
|
||||
if aIsLast != bIsLast {
|
||||
if aIsLast {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
aExists := existingModels[a.Name]
|
||||
bExists := existingModels[b.Name]
|
||||
if aExists != bExists {
|
||||
if aExists {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
return items, existingModels
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -88,10 +89,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
// Mock TUI function (not called in these tests)
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
@@ -124,75 +123,6 @@ func TestLaunchCmd(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmd_TUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
// Will error because claude isn't configured, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model flag provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config flag provided")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
@@ -233,7 +163,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
cmd := LaunchCmd(nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
@@ -368,27 +298,18 @@ func TestParseArgs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIsCloudModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"glm-4.7:cloud", true},
|
||||
{"kimi-k2.5:cloud", true},
|
||||
{"glm-4.7-flash", false},
|
||||
{"glm-4.7-flash:latest", false},
|
||||
{"cloud-model", false},
|
||||
{"model:cloudish", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isCloudModel(tt.name); got != tt.want {
|
||||
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
// isCloudModel now only uses Show API, so nil client always returns false
|
||||
t.Run("nil client returns false", func(t *testing.T) {
|
||||
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
|
||||
for _, model := range models {
|
||||
if isCloudModel(context.Background(), nil, model) {
|
||||
t.Errorf("isCloudModel(%q) with nil client should return false", model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func names(items []ModelItem) []string {
|
||||
func names(items []selectItem) []string {
|
||||
var out []string
|
||||
for _, item := range items {
|
||||
out = append(out, item.Name)
|
||||
@@ -580,3 +501,19 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,8 +17,6 @@ type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
|
||||
@@ -17,6 +17,7 @@ const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
|
||||
@@ -96,6 +96,14 @@ func TestSelectState(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
@@ -574,8 +582,19 @@ func TestRenderSelect(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "no matches") {
|
||||
t.Errorf("expected 'no matches' message, got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
t.Error("expected 'no matches' message for empty list with no filter")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -10,19 +10,21 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
return errNotRunning
|
||||
}
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,509 +0,0 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
selectorTitleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("147"))
|
||||
|
||||
selectorItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4)
|
||||
|
||||
selectorSelectedItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Foreground(lipgloss.Color("147")).
|
||||
Bold(true)
|
||||
|
||||
selectorDescStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
selectorFilterStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
|
||||
selectorInputStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252"))
|
||||
|
||||
selectorCheckboxStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("147"))
|
||||
|
||||
selectorDefaultTagStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
|
||||
selectorHelpStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
selectorMoreStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4).
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
const maxSelectorItems = 10
|
||||
|
||||
// ErrCancelled is returned when the user cancels the selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// SelectItem represents an item that can be selected.
|
||||
type SelectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// selectorModel is the bubbletea model for single selection.
|
||||
type selectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
selected string
|
||||
cancelled bool
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m selectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.selected = filtered[m.cursor].Name
|
||||
}
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.scrollOffset -= maxSelectorItems
|
||||
if m.scrollOffset < 0 {
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m selectorModel) View() string {
|
||||
// Clear screen when exiting
|
||||
if m.cancelled || m.selected != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
// Title with filter
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
|
||||
if item.Description != "" {
|
||||
s.WriteString(" ")
|
||||
s.WriteString(selectorDescStyle.Render("- " + item.Description))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
// SelectSingle prompts the user to select a single item from a list.
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(selectorModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.selected, nil
|
||||
}
|
||||
|
||||
// multiSelectorModel is the bubbletea model for multi selection.
|
||||
type multiSelectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
m := multiSelectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) toggleItem() {
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) == 0 || m.cursor >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[m.cursor]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.checked[origIdx] {
|
||||
delete(m.checked, origIdx)
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == origIdx {
|
||||
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
m.checked[origIdx] = true
|
||||
m.checkOrder = append(m.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) selectedCount() int {
|
||||
return len(m.checkOrder)
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
// Enter confirms if at least one item is selected
|
||||
if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
// Space always toggles selection
|
||||
m.toggleItem()
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.scrollOffset -= maxSelectorItems
|
||||
if m.scrollOffset < 0 {
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) View() string {
|
||||
// Clear screen when exiting
|
||||
if m.cancelled || m.confirmed {
|
||||
return ""
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
// Title with filter
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
// Checkbox
|
||||
var checkbox string
|
||||
if m.checked[origIdx] {
|
||||
checkbox = selectorCheckboxCheckedStyle.Render("[x]")
|
||||
} else {
|
||||
checkbox = selectorCheckboxStyle.Render("[ ]")
|
||||
}
|
||||
|
||||
// Cursor and name
|
||||
var line string
|
||||
if idx == m.cursor {
|
||||
line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
|
||||
} else {
|
||||
line = " " + checkbox + " " + item.Name
|
||||
}
|
||||
|
||||
// Default tag
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
line += " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
s.WriteString(line)
|
||||
s.WriteString("\n")
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
// Status line
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
// SelectMultiple prompts the user to select multiple items from a list.
|
||||
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := newMultiSelectorModel(title, items, preChecked)
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, idx := range fm.checkOrder {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
808
cmd/tui/tui.go
808
cmd/tui/tui.go
@@ -1,808 +0,0 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const (
|
||||
logoNormal = ` ▆▁▂▃▂▁▆
|
||||
▟███████▙
|
||||
█▙▛▅ ▅▜▟█
|
||||
▟█▙▀▀▀▟█▙
|
||||
█████████
|
||||
▟███████▙
|
||||
▀▀▀▀▀▀▀▀▀`
|
||||
|
||||
logoBlink = ` ▆▁▂▃▂▁▆
|
||||
▟███████▙
|
||||
██▛▅ ▅▜██
|
||||
▟█▙▀▀▀▟█▙
|
||||
█████████
|
||||
▟███████▙
|
||||
▀▀▀▀▀▀▀▀▀`
|
||||
|
||||
// logoBlank is used for terminals that don't render the logo well
|
||||
logoBlank = `
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
`
|
||||
|
||||
blinkInterval = 15 * time.Second
|
||||
blinkDuration = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
type (
|
||||
blinkMsg struct{}
|
||||
unblinkMsg struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
logoStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("255")).
|
||||
Background(lipgloss.Color("0"))
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
MarginBottom(1)
|
||||
|
||||
versionStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("245"))
|
||||
|
||||
itemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2)
|
||||
|
||||
selectedStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Foreground(lipgloss.Color("147")).
|
||||
Bold(true)
|
||||
|
||||
greyedStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
greyedSelectedStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Foreground(lipgloss.Color("243"))
|
||||
|
||||
descStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4).
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
modelStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("245"))
|
||||
|
||||
notInstalledStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string // integration name for loading model config, empty if not an integration
|
||||
isRunModel bool // true for the "Run a model" option
|
||||
isOthers bool // true for the "Others..." toggle item
|
||||
}
|
||||
|
||||
var mainMenuItems = []menuItem{
|
||||
{
|
||||
title: "Run a model",
|
||||
description: "Start an interactive chat with a local model",
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
title: "Launch Claude Code",
|
||||
description: "Open Claude Code AI assistant",
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
title: "Launch Open Claw",
|
||||
description: "Open the Open Claw integration",
|
||||
integration: "openclaw",
|
||||
},
|
||||
}
|
||||
|
||||
var othersMenuItem = menuItem{
|
||||
title: "Others...",
|
||||
description: "Show additional integrations",
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
// getOtherIntegrations returns the list of other integrations, filtering out
|
||||
// Codex if it's not installed (since it requires npm install).
|
||||
func getOtherIntegrations() []menuItem {
|
||||
items := []menuItem{
|
||||
{
|
||||
title: "Launch Droid",
|
||||
description: "Open Droid integration",
|
||||
integration: "droid",
|
||||
},
|
||||
{
|
||||
title: "Launch Open Code",
|
||||
description: "Open Open Code integration",
|
||||
integration: "opencode",
|
||||
},
|
||||
}
|
||||
|
||||
// Only show Codex if it's already installed
|
||||
if config.IsIntegrationInstalled("codex") {
|
||||
items = append([]menuItem{{
|
||||
title: "Launch Codex",
|
||||
description: "Open Codex CLI",
|
||||
integration: "codex",
|
||||
}}, items...)
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
type model struct {
|
||||
items []menuItem
|
||||
cursor int
|
||||
quitting bool
|
||||
selected bool // true if user made a selection (enter/space)
|
||||
changeModel bool // true if user pressed right arrow to change model
|
||||
showOthers bool // true if "Others..." is expanded
|
||||
availableModels map[string]bool // cache of available model names
|
||||
blinking bool // true when showing blink logo
|
||||
err error
|
||||
|
||||
// Modal state
|
||||
showingModal bool // true when model picker modal is visible
|
||||
modalSelector selectorModel // the selector model for the modal
|
||||
modalItems []SelectItem // cached items for the modal
|
||||
|
||||
// Sign-in dialog state
|
||||
showingSignIn bool // true when sign-in dialog is visible
|
||||
signInURL string // URL for sign-in
|
||||
signInModel string // model that requires sign-in
|
||||
signInSpinner int // spinner frame index
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
}
|
||||
|
||||
// signInTickMsg is sent to animate the sign-in spinner
|
||||
type signInTickMsg struct{}
|
||||
|
||||
// signInCheckMsg is sent to check if sign-in is complete
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
// modelExists checks if a model exists in the cached available models.
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
return true
|
||||
}
|
||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
||||
for modelName := range m.availableModels {
|
||||
if strings.HasPrefix(modelName, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// buildModalItems creates the list of models for the modal selector.
|
||||
func (m *model) buildModalItems() []SelectItem {
|
||||
modelItems, _ := config.GetModelItems(context.Background())
|
||||
var items []SelectItem
|
||||
for _, item := range modelItems {
|
||||
items = append(items, SelectItem{Name: item.Name, Description: item.Description})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// openModelModal opens the model picker modal.
|
||||
func (m *model) openModelModal() {
|
||||
m.modalItems = m.buildModalItems()
|
||||
m.modalSelector = selectorModel{
|
||||
title: "Select model:",
|
||||
items: m.modalItems,
|
||||
}
|
||||
m.showingModal = true
|
||||
}
|
||||
|
||||
// isCloudModel returns true if the model name indicates a cloud model.
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if modelName == "" || !isCloudModel(modelName) {
|
||||
return nil
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil // Already signed in
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSignIn initiates the sign-in flow for a cloud model.
|
||||
// fromModal indicates if this was triggered from the model picker modal.
|
||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
||||
m.showingModal = false
|
||||
m.showingSignIn = true
|
||||
m.signInURL = signInURL
|
||||
m.signInModel = modelName
|
||||
m.signInSpinner = 0
|
||||
m.signInFromModal = fromModal
|
||||
|
||||
// Open browser (best effort)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", signInURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", signInURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", signInURL).Start()
|
||||
}
|
||||
|
||||
// Start the spinner tick
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
// checkSignIn checks if the user has completed sign-in.
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
// loadAvailableModels fetches and caches the list of available models.
|
||||
func (m *model) loadAvailableModels() {
|
||||
m.availableModels = make(map[string]bool)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, mdl := range models.Models {
|
||||
m.availableModels[mdl.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) buildItems() {
|
||||
others := getOtherIntegrations()
|
||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
||||
m.items = append(m.items, mainMenuItems...)
|
||||
|
||||
if m.showOthers {
|
||||
// Change "Others..." to "Hide others..."
|
||||
hideItem := menuItem{
|
||||
title: "Hide others...",
|
||||
description: "Hide additional integrations",
|
||||
isOthers: true,
|
||||
}
|
||||
m.items = append(m.items, hideItem)
|
||||
m.items = append(m.items, others...)
|
||||
} else {
|
||||
m.items = append(m.items, othersMenuItem)
|
||||
}
|
||||
}
|
||||
|
||||
// isOthersIntegration returns true if the integration is in the "Others" menu
|
||||
func isOthersIntegration(name string) bool {
|
||||
switch name {
|
||||
case "codex", "droid", "opencode":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initialModel() model {
|
||||
m := model{
|
||||
cursor: 0,
|
||||
}
|
||||
m.loadAvailableModels()
|
||||
|
||||
// Check last selection to determine if we need to expand "Others"
|
||||
lastSelection := config.LastSelection()
|
||||
if isOthersIntegration(lastSelection) {
|
||||
m.showOthers = true
|
||||
}
|
||||
|
||||
m.buildItems()
|
||||
|
||||
// Position cursor on last selection
|
||||
if lastSelection != "" {
|
||||
for i, item := range m.items {
|
||||
if lastSelection == "run" && item.isRunModel {
|
||||
m.cursor = i
|
||||
break
|
||||
} else if item.integration == lastSelection {
|
||||
m.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return tea.Tick(blinkInterval, func(t time.Time) tea.Msg {
|
||||
return blinkMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Handle sign-in dialog
|
||||
if m.showingSignIn {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
// Cancel sign-in and go back
|
||||
m.showingSignIn = false
|
||||
if m.signInFromModal {
|
||||
m.showingModal = true
|
||||
}
|
||||
// If from main menu, just return to main menu (default state)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.signInSpinner++
|
||||
// Check sign-in status every 5th tick (~1 second)
|
||||
if m.signInSpinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
// Sign-in complete - proceed with selection
|
||||
if m.signInFromModal {
|
||||
// Came from modal - set changeModel
|
||||
m.modalSelector.selected = m.signInModel
|
||||
m.changeModel = true
|
||||
} else {
|
||||
// Came from main menu - just select
|
||||
m.selected = true
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Handle modal input if modal is showing
|
||||
if m.showingModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
// Close modal without selection
|
||||
m.showingModal = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
||||
}
|
||||
if m.modalSelector.selected != "" {
|
||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
// Selection made - exit with changeModel
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.modalSelector.cursor > 0 {
|
||||
m.modalSelector.cursor--
|
||||
if m.modalSelector.cursor < m.modalSelector.scrollOffset {
|
||||
m.modalSelector.scrollOffset = m.modalSelector.cursor
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if m.modalSelector.cursor < len(filtered)-1 {
|
||||
m.modalSelector.cursor++
|
||||
if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
|
||||
m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
m.modalSelector.cursor -= maxSelectorItems
|
||||
if m.modalSelector.cursor < 0 {
|
||||
m.modalSelector.cursor = 0
|
||||
}
|
||||
m.modalSelector.scrollOffset -= maxSelectorItems
|
||||
if m.modalSelector.scrollOffset < 0 {
|
||||
m.modalSelector.scrollOffset = 0
|
||||
}
|
||||
_ = filtered // suppress unused warning
|
||||
|
||||
case tea.KeyPgDown:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
m.modalSelector.cursor += maxSelectorItems
|
||||
if m.modalSelector.cursor >= len(filtered) {
|
||||
m.modalSelector.cursor = len(filtered) - 1
|
||||
}
|
||||
if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
|
||||
m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
|
||||
}
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.modalSelector.filter) > 0 {
|
||||
m.modalSelector.filter = m.modalSelector.filter[:len(m.modalSelector.filter)-1]
|
||||
m.modalSelector.cursor = 0
|
||||
m.modalSelector.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.modalSelector.filter += string(msg.Runes)
|
||||
m.modalSelector.cursor = 0
|
||||
m.modalSelector.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case blinkMsg:
|
||||
m.blinking = true
|
||||
return m, tea.Tick(blinkDuration, func(t time.Time) tea.Msg {
|
||||
return unblinkMsg{}
|
||||
})
|
||||
|
||||
case unblinkMsg:
|
||||
m.blinking = false
|
||||
return m, tea.Tick(blinkInterval, func(t time.Time) tea.Msg {
|
||||
return blinkMsg{}
|
||||
})
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
// Handle "Others..." toggle
|
||||
if item.isOthers {
|
||||
m.showOthers = !m.showOthers
|
||||
m.buildItems()
|
||||
// Keep cursor on the Others/Hide item
|
||||
if m.cursor >= len(m.items) {
|
||||
m.cursor = len(m.items) - 1
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Don't allow selecting uninstalled integrations
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Check if a cloud model is configured and needs sign-in
|
||||
var configuredModel string
|
||||
if item.isRunModel {
|
||||
configuredModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
configuredModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "right", "l":
|
||||
// Allow model change for integrations and run model
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
// Don't allow for uninstalled integrations
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
m.openModelModal()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Render sign-in dialog if showing
|
||||
if m.showingSignIn {
|
||||
return m.renderSignInDialog()
|
||||
}
|
||||
|
||||
// Render modal overlay if showing - replaces main view
|
||||
if m.showingModal {
|
||||
return m.renderModal()
|
||||
}
|
||||
|
||||
logo := logoNormal
|
||||
if m.blinking {
|
||||
logo = logoBlink
|
||||
}
|
||||
if os.Getenv("TERM_PROGRAM") == "Apple_Terminal" {
|
||||
logo = logoBlank
|
||||
}
|
||||
|
||||
versionText := "\n\n Ollama " + versionStyle.Render("v"+version.Version)
|
||||
|
||||
logoRendered := logoStyle.Render(logo)
|
||||
logoBlock := lipgloss.NewStyle().Padding(0, 1).MarginLeft(2).Background(lipgloss.Color("0")).Render(logoRendered)
|
||||
versionBlock := titleStyle.Render(versionText)
|
||||
header := lipgloss.JoinHorizontal(lipgloss.Top, logoBlock, versionBlock)
|
||||
|
||||
s := header + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
cursor := " "
|
||||
style := itemStyle
|
||||
isInstalled := true
|
||||
|
||||
if item.integration != "" {
|
||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
||||
}
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = "▸ "
|
||||
if isInstalled {
|
||||
style = selectedStyle
|
||||
} else {
|
||||
style = greyedSelectedStyle
|
||||
}
|
||||
} else if !isInstalled && item.integration != "" {
|
||||
style = greyedStyle
|
||||
}
|
||||
|
||||
title := item.title
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
} else if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
title += " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
} else if item.isRunModel {
|
||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
||||
title += " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
|
||||
s += style.Render(cursor+title) + "\n"
|
||||
s += descStyle.Render(item.description) + "\n\n"
|
||||
}
|
||||
|
||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render("↑/↓ navigate • enter select • → change model • esc quit")
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// renderModal renders the model picker modal.
|
||||
func (m model) renderModal() string {
|
||||
modalStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("147")).
|
||||
Padding(1, 2).
|
||||
MarginLeft(2)
|
||||
|
||||
var content strings.Builder
|
||||
|
||||
// Title with filter
|
||||
content.WriteString(selectorTitleStyle.Render(m.modalSelector.title))
|
||||
content.WriteString(" ")
|
||||
if m.modalSelector.filter == "" {
|
||||
content.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
content.WriteString(selectorInputStyle.Render(m.modalSelector.filter))
|
||||
}
|
||||
content.WriteString("\n\n")
|
||||
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
content.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
content.WriteString("\n")
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.modalSelector.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
|
||||
if idx == m.modalSelector.cursor {
|
||||
content.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
content.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
|
||||
if item.Description != "" {
|
||||
content.WriteString(" ")
|
||||
content.WriteString(selectorDescStyle.Render("- " + item.Description))
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.modalSelector.scrollOffset - displayCount; remaining > 0 {
|
||||
content.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
content.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
content.WriteString("\n")
|
||||
content.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
|
||||
|
||||
return modalStyle.Render(content.String())
|
||||
}
|
||||
|
||||
// renderSignInDialog renders the sign-in dialog.
|
||||
func (m model) renderSignInDialog() string {
|
||||
dialogStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("147")).
|
||||
Padding(1, 2).
|
||||
MarginLeft(2)
|
||||
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
spinner := spinnerFrames[m.signInSpinner%len(spinnerFrames)]
|
||||
|
||||
var content strings.Builder
|
||||
|
||||
content.WriteString(selectorTitleStyle.Render("Sign in required"))
|
||||
content.WriteString("\n\n")
|
||||
|
||||
content.WriteString(fmt.Sprintf("To use %s, please sign in.\n\n", selectedStyle.Render(m.signInModel)))
|
||||
|
||||
content.WriteString("Navigate to:\n")
|
||||
content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Render(" " + m.signInURL))
|
||||
content.WriteString("\n\n")
|
||||
|
||||
content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(
|
||||
fmt.Sprintf("%s Waiting for sign in to complete...", spinner)))
|
||||
content.WriteString("\n\n")
|
||||
|
||||
content.WriteString(selectorHelpStyle.Render("esc cancel"))
|
||||
|
||||
return dialogStyle.Render(content.String())
|
||||
}
|
||||
|
||||
// Selection represents what the user selected
|
||||
type Selection int
|
||||
|
||||
const (
|
||||
SelectionNone Selection = iota
|
||||
SelectionRunModel
|
||||
SelectionChangeRunModel
|
||||
SelectionIntegration // Generic integration selection
|
||||
SelectionChangeIntegration // Generic change model for integration
|
||||
)
|
||||
|
||||
// Result contains the selection and any associated data
|
||||
type Result struct {
|
||||
Selection Selection
|
||||
Integration string // integration name if applicable
|
||||
Model string // model name if selected from modal
|
||||
}
|
||||
|
||||
// Run starts the TUI and returns the user's selection
|
||||
func Run() (Result, error) {
|
||||
m := initialModel()
|
||||
p := tea.NewProgram(m)
|
||||
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(model)
|
||||
if fm.err != nil {
|
||||
return Result{Selection: SelectionNone}, fm.err
|
||||
}
|
||||
|
||||
// User quit without selecting
|
||||
if !fm.selected && !fm.changeModel {
|
||||
return Result{Selection: SelectionNone}, nil
|
||||
}
|
||||
|
||||
item := fm.items[fm.cursor]
|
||||
|
||||
// Handle model change request
|
||||
if fm.changeModel {
|
||||
if item.isRunModel {
|
||||
return Result{
|
||||
Selection: SelectionChangeRunModel,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
return Result{
|
||||
Selection: SelectionChangeIntegration,
|
||||
Integration: item.integration,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handle selection
|
||||
if item.isRunModel {
|
||||
return Result{Selection: SelectionRunModel}, nil
|
||||
}
|
||||
|
||||
return Result{
|
||||
Selection: SelectionIntegration,
|
||||
Integration: item.integration,
|
||||
}, nil
|
||||
}
|
||||
18
go.mod
18
go.mod
@@ -21,12 +21,10 @@ require (
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
@@ -40,34 +38,22 @@ require (
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/tkrajina/go-reflector v0.5.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
|
||||
36
go.sum
36
go.sum
@@ -14,8 +14,6 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
@@ -26,18 +24,6 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
|
||||
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
|
||||
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
|
||||
@@ -73,8 +59,6 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
@@ -164,17 +148,13 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
||||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -182,12 +162,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
|
||||
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
@@ -208,9 +182,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
@@ -247,8 +220,6 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -335,7 +306,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
// TestNumPredict verifies that when num_predict is set, the model generates
|
||||
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
||||
func TestNumPredict(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
|
||||
t.Fatalf("failed to pull model: %v", err)
|
||||
}
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "qwen3:0.6b",
|
||||
Prompt: "Write a long story.",
|
||||
Stream: &stream,
|
||||
Logprobs: true,
|
||||
Options: map[string]any{
|
||||
"num_predict": 10,
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
logprobCount := 0
|
||||
var finalResponse api.GenerateResponse
|
||||
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
|
||||
logprobCount += len(resp.Logprobs)
|
||||
if resp.Done {
|
||||
finalResponse = resp
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("generate failed: %v", err)
|
||||
}
|
||||
|
||||
if logprobCount != 10 {
|
||||
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
|
||||
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
|
||||
@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
nextBatch.seqs[seqIdx] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||
seq.inputs = []*input.Input{nextToken}
|
||||
nextBatchTokens[i] = nextToken
|
||||
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
seq.numPredicted++
|
||||
if seq.numPredicted == 1 {
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/bin/sh
|
||||
# This script installs Ollama on Linux.
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
set -eu
|
||||
@@ -27,8 +27,7 @@ require() {
|
||||
echo $MISSING
|
||||
}
|
||||
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
OS="$(uname -s)"
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
@@ -36,6 +35,65 @@ case "$ARCH" in
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
###########################################
|
||||
# macOS
|
||||
###########################################
|
||||
|
||||
if [ "$OS" = "Darwin" ]; then
|
||||
NEEDS=$(require curl unzip)
|
||||
if [ -n "$NEEDS" ]; then
|
||||
status "ERROR: The following tools are required but missing:"
|
||||
for NEED in $NEEDS; do
|
||||
echo " - $NEED"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -n "${OLLAMA_VERSION:-}" ]; then
|
||||
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
|
||||
else
|
||||
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
|
||||
fi
|
||||
|
||||
if pgrep -x Ollama >/dev/null 2>&1; then
|
||||
status "Stopping running Ollama instance..."
|
||||
pkill -x Ollama 2>/dev/null || true
|
||||
sleep 2
|
||||
fi
|
||||
|
||||
if [ -d "/Applications/Ollama.app" ]; then
|
||||
status "Removing existing Ollama installation..."
|
||||
rm -rf "/Applications/Ollama.app"
|
||||
fi
|
||||
|
||||
status "Downloading Ollama for macOS..."
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
|
||||
|
||||
status "Installing Ollama to /Applications..."
|
||||
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
|
||||
mv "$TEMP_DIR/Ollama.app" "/Applications/"
|
||||
|
||||
status "Adding 'ollama' command to PATH (may require password)..."
|
||||
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
|
||||
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
|
||||
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
|
||||
|
||||
if [ -z "${OLLAMA_NO_START:-}" ]; then
|
||||
status "Starting Ollama..."
|
||||
open -a Ollama --args hidden
|
||||
fi
|
||||
|
||||
status "Install complete. You can now run 'ollama'."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
###########################################
|
||||
# Linux
|
||||
###########################################
|
||||
|
||||
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
|
||||
|
||||
IS_WSL2=false
|
||||
|
||||
KERN=$(uname -r)
|
||||
|
||||
422
server/aliases.go
Normal file
422
server/aliases.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const (
|
||||
serverConfigFilename = "server.json"
|
||||
serverConfigVersion = 1
|
||||
)
|
||||
|
||||
var errAliasCycle = errors.New("alias cycle detected")
|
||||
|
||||
type aliasEntry struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
type serverConfig struct {
|
||||
Version int `json:"version"`
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type store struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
|
||||
prefixEntries []aliasEntry // prefix matches, sorted longest-first
|
||||
}
|
||||
|
||||
func createStore(path string) (*store, error) {
|
||||
store := &store{
|
||||
path: path,
|
||||
entries: make(map[string]aliasEntry),
|
||||
}
|
||||
if err := store.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *store) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg serverConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
|
||||
return fmt.Errorf("unsupported router config version %d", cfg.Version)
|
||||
}
|
||||
|
||||
for _, entry := range cfg.Aliases {
|
||||
targetName := model.ParseName(entry.Target)
|
||||
if !targetName.IsValid() {
|
||||
slog.Warn("invalid alias target in router config", "target", entry.Target)
|
||||
continue
|
||||
}
|
||||
canonicalTarget := displayAliasName(targetName)
|
||||
|
||||
if entry.PrefixMatching {
|
||||
// Prefix aliases don't need to be valid model names
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
slog.Warn("empty prefix alias in router config")
|
||||
continue
|
||||
}
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: alias,
|
||||
Target: canonicalTarget,
|
||||
PrefixMatching: true,
|
||||
})
|
||||
} else {
|
||||
aliasName := model.ParseName(entry.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
|
||||
continue
|
||||
}
|
||||
canonicalAlias := displayAliasName(aliasName)
|
||||
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
|
||||
Alias: canonicalAlias,
|
||||
Target: canonicalTarget,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort prefix entries by alias length descending (longest prefix wins)
|
||||
s.sortPrefixEntriesLocked()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *store) saveLocked() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine exact and prefix entries
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
|
||||
cfg := serverConfig{
|
||||
Version: serverConfigVersion,
|
||||
Aliases: entries,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(dir, "router-*.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Chmod(f.Name(), 0o644); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Rename(f.Name(), s.path)
|
||||
}
|
||||
|
||||
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
|
||||
// If a local model exists, do not allow alias shadowing (highest priority).
|
||||
exists, err := localModelExists(name)
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
if exists {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
key := normalizeAliasKey(name)
|
||||
|
||||
s.mu.RLock()
|
||||
entry, exactMatch := s.entries[key]
|
||||
var prefixMatch *aliasEntry
|
||||
if !exactMatch {
|
||||
// Try prefix matching - prefixEntries is sorted longest-first
|
||||
nameStr := strings.ToLower(displayAliasName(name))
|
||||
for i := range s.prefixEntries {
|
||||
prefix := strings.ToLower(s.prefixEntries[i].Alias)
|
||||
if strings.HasPrefix(nameStr, prefix) {
|
||||
prefixMatch = &s.prefixEntries[i]
|
||||
break // First match is longest due to sorting
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exactMatch && prefixMatch == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
var current string
|
||||
var visited map[string]struct{}
|
||||
|
||||
if exactMatch {
|
||||
visited = map[string]struct{}{key: {}}
|
||||
current = entry.Target
|
||||
} else {
|
||||
// For prefix match, use the target as-is
|
||||
visited = map[string]struct{}{}
|
||||
current = prefixMatch.Target
|
||||
}
|
||||
|
||||
targetKey := normalizeAliasKeyString(current)
|
||||
|
||||
for {
|
||||
targetName := model.ParseName(current)
|
||||
if !targetName.IsValid() {
|
||||
return name, false, fmt.Errorf("alias target %q is invalid", current)
|
||||
}
|
||||
|
||||
if _, seen := visited[targetKey]; seen {
|
||||
return name, false, errAliasCycle
|
||||
}
|
||||
visited[targetKey] = struct{}{}
|
||||
|
||||
s.mu.RLock()
|
||||
next, ok := s.entries[targetKey]
|
||||
s.mu.RUnlock()
|
||||
if !ok {
|
||||
return targetName, true, nil
|
||||
}
|
||||
|
||||
current = next.Target
|
||||
targetKey = normalizeAliasKeyString(current)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
|
||||
targetKey := normalizeAliasKey(target)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if prefixMatching {
|
||||
// For prefix aliases, we skip cycle detection since prefix matching
|
||||
// works differently and the target is a specific model
|
||||
aliasStr := displayAliasName(alias)
|
||||
|
||||
// Remove any existing prefix entry with the same alias
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: aliasStr,
|
||||
Target: displayAliasName(target),
|
||||
PrefixMatching: true,
|
||||
})
|
||||
s.sortPrefixEntriesLocked()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
if aliasKey == targetKey {
|
||||
return fmt.Errorf("alias cannot point to itself")
|
||||
}
|
||||
|
||||
visited := map[string]struct{}{aliasKey: {}}
|
||||
currentKey := targetKey
|
||||
for {
|
||||
if _, seen := visited[currentKey]; seen {
|
||||
return errAliasCycle
|
||||
}
|
||||
visited[currentKey] = struct{}{}
|
||||
|
||||
next, ok := s.entries[currentKey]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
currentKey = normalizeAliasKeyString(next.Target)
|
||||
}
|
||||
|
||||
s.entries[aliasKey] = aliasEntry{
|
||||
Alias: displayAliasName(alias),
|
||||
Target: displayAliasName(target),
|
||||
}
|
||||
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *store) Delete(alias model.Name) (bool, error) {
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try exact match first
|
||||
if _, ok := s.entries[aliasKey]; ok {
|
||||
delete(s.entries, aliasKey)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
// Try prefix entries
|
||||
aliasStr := displayAliasName(alias)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// DeleteByString deletes an alias by its raw string value, useful for prefix
|
||||
// aliases that may not be valid model names.
|
||||
func (s *store) DeleteByString(alias string) (bool, error) {
|
||||
alias = strings.TrimSpace(alias)
|
||||
aliasLower := strings.ToLower(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try prefix entries first (since this is mainly for prefix aliases)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, alias) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// Also check exact entries by normalized key
|
||||
if _, ok := s.entries[aliasLower]; ok {
|
||||
delete(s.entries, aliasLower)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *store) List() []aliasEntry {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func normalizeAliasKey(name model.Name) string {
|
||||
return strings.ToLower(displayAliasName(name))
|
||||
}
|
||||
|
||||
func (s *store) sortPrefixEntriesLocked() {
|
||||
sort.Slice(s.prefixEntries, func(i, j int) bool {
|
||||
// Sort by length descending (longest prefix first)
|
||||
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeAliasKeyString(value string) string {
|
||||
n := model.ParseName(value)
|
||||
if !n.IsValid() {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
return normalizeAliasKey(n)
|
||||
}
|
||||
|
||||
func displayAliasName(n model.Name) string {
|
||||
display := n.DisplayShortest()
|
||||
if strings.EqualFold(n.Tag, "latest") {
|
||||
if idx := strings.LastIndex(display, ":"); idx != -1 {
|
||||
return display[:idx]
|
||||
}
|
||||
}
|
||||
return display
|
||||
}
|
||||
|
||||
func localModelExists(name model.Name) (bool, error) {
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
needle := name.String()
|
||||
for existing := range manifests {
|
||||
if strings.EqualFold(existing.String(), needle) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func serverConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(".ollama", serverConfigFilename)
|
||||
}
|
||||
return filepath.Join(home, ".ollama", serverConfigFilename)
|
||||
}
|
||||
|
||||
func (s *Server) aliasStore() (*store, error) {
|
||||
s.aliasesOnce.Do(func() {
|
||||
s.aliases, s.aliasesErr = createStore(serverConfigPath())
|
||||
})
|
||||
|
||||
return s.aliases, s.aliasesErr
|
||||
}
|
||||
|
||||
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
|
||||
if store == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
return store.ResolveName(name)
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -81,6 +82,9 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
aliasesOnce sync.Once
|
||||
aliases *store
|
||||
aliasesErr error
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -191,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
// We cannot currently consolidate this into GetModel because all we'll
|
||||
// induce infinite recursion given the current code structure.
|
||||
name, err := getExistingName(name)
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -1580,6 +1591,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||
r.POST("/api/copy", s.CopyHandler)
|
||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
@@ -1950,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name, err := getExistingName(name)
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
m, err := GetModel(name.String())
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
|
||||
159
server/routes_aliases.go
Normal file
159
server/routes_aliases.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type aliasListResponse struct {
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type aliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
func (s *Server) ListAliasesHandler(c *gin.Context) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var aliases []aliasEntry
|
||||
if store != nil {
|
||||
aliases = store.List()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
|
||||
}
|
||||
|
||||
func (s *Server) CreateAliasHandler(c *gin.Context) {
|
||||
var req aliasEntry
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
req.Target = strings.TrimSpace(req.Target)
|
||||
if req.Alias == "" || req.Target == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Target must always be a valid model name
|
||||
targetName := model.ParseName(req.Target)
|
||||
if !targetName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
|
||||
return
|
||||
}
|
||||
|
||||
var aliasName model.Name
|
||||
if req.PrefixMatching {
|
||||
// For prefix aliases, we still parse the alias to normalize it,
|
||||
// but we allow any non-empty string since prefix patterns may not be valid model names
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
// Even if not valid as a model name, we accept it for prefix matching
|
||||
} else {
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := localModelExists(aliasName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, errAliasCycle) {
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp := aliasEntry{
|
||||
Alias: displayAliasName(aliasName),
|
||||
Target: displayAliasName(targetName),
|
||||
PrefixMatching: req.PrefixMatching,
|
||||
}
|
||||
if req.PrefixMatching && !aliasName.IsValid() {
|
||||
// For prefix aliases that aren't valid model names, use the raw alias
|
||||
resp.Alias = req.Alias
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) DeleteAliasHandler(c *gin.Context) {
|
||||
var req aliasDeleteRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
if req.Alias == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
|
||||
return
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
aliasName := model.ParseName(req.Alias)
|
||||
var deleted bool
|
||||
if aliasName.IsValid() {
|
||||
deleted, err = store.Delete(aliasName)
|
||||
} else {
|
||||
// For invalid model names (like prefix aliases), try deleting by raw string
|
||||
deleted, err = store.DeleteByString(req.Alias)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !deleted {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
426
server/routes_aliases_test.go
Normal file
426
server/routes_aliases_test.go
Normal file
@@ -0,0 +1,426 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestAliasShadowingRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "shadowed-model",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasResolvesForChatRemote(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
var remoteModel string
|
||||
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req api.ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remoteModel = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
|
||||
p, err := url.Parse(rs.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "target-model",
|
||||
RemoteHost: rs.URL,
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "alias-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "hi"}},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Model != "alias-model" {
|
||||
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
|
||||
}
|
||||
|
||||
if remoteModel != "test" {
|
||||
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasBasicMatching(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a prefix alias: "myprefix-" -> "targetmodel"
|
||||
targetName := model.ParseName("targetmodel")
|
||||
|
||||
// Set a prefix alias (using "myprefix-" as the pattern)
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = append(store.prefixEntries, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "targetmodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that "myprefix-foo" resolves to "targetmodel"
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != targetName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test that "otherprefix-foo" does not resolve
|
||||
otherName := model.ParseName("otherprefix-foo")
|
||||
_, wasResolved, err = store.ResolveName(otherName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatal("expected name not to be resolved")
|
||||
}
|
||||
|
||||
// Test that exact alias takes precedence
|
||||
exactAlias := model.ParseName("myprefix-exact")
|
||||
exactTarget := model.ParseName("exacttarget")
|
||||
if err := store.Set(exactAlias, exactTarget, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resolved, wasResolved, err = store.ResolveName(exactAlias)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLongestMatchWins(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add two prefix aliases with overlapping patterns
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
|
||||
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
|
||||
}
|
||||
store.sortPrefixEntriesLocked()
|
||||
store.mu.Unlock()
|
||||
|
||||
// "abc-def-ghi" should match the longer prefix "abc-def-"
|
||||
testName := model.ParseName("abc-def-ghi")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedLongTarget := model.ParseName("long-target")
|
||||
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// "abc-xyz" should match the shorter prefix "abc-"
|
||||
testName2 := model.ParseName("abc-xyz")
|
||||
resolved, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedShortTarget := model.ParseName("short-target")
|
||||
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a chain: prefix "test-" -> "intermediate" -> "final"
|
||||
intermediate := model.ParseName("intermediate")
|
||||
final := model.ParseName("final")
|
||||
|
||||
// Add prefix alias
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Add exact alias for the intermediate step
|
||||
if err := store.Set(intermediate, final, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// "test-foo" should resolve through the chain to "final"
|
||||
testName := model.ParseName("test-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != final.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCRUD(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a prefix alias via API
|
||||
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "llama2",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var createResp aliasEntry
|
||||
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !createResp.PrefixMatching {
|
||||
t.Fatal("expected prefix_matching to be true in response")
|
||||
}
|
||||
|
||||
// List aliases and verify the prefix alias is included
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var listResp aliasListResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching && a.Target == "llama2" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected to find prefix alias in list")
|
||||
}
|
||||
|
||||
// Delete the prefix alias
|
||||
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify it's deleted
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching {
|
||||
t.Fatal("expected prefix alias to be deleted")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a prefix alias with mixed case
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that matching is case-insensitive
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (case-insensitive)")
|
||||
}
|
||||
expectedTarget := model.ParseName("targetmodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test uppercase request
|
||||
testName2 := model.ParseName("MYPREFIX-BAR")
|
||||
_, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (uppercase)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a local model that would match a prefix alias
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "myprefix-localmodel",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Create a prefix alias that would match the local model name
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "someothermodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
localModelName := model.ParseName("myprefix-localmodel")
|
||||
resolved, wasResolved, err := store.ResolveName(localModelName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
|
||||
}
|
||||
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
|
||||
nonLocalName := model.ParseName("myprefix-nonexistent")
|
||||
resolved, wasResolved, err = store.ResolveName(nonLocalName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected non-local model to resolve via prefix alias")
|
||||
}
|
||||
expectedTarget := model.ParseName("someothermodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
@@ -417,9 +417,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
|
||||
numParallel = 1
|
||||
}
|
||||
|
||||
// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and uses an encoder cache which cannot be used with num_parallel > 1
|
||||
// Some architectures are not safe with num_parallel > 1.
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user