diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index f5799fc1e..a5ff995f5 100755 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -518,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string { // StreamConverter manages state for converting Ollama streaming responses to Anthropic format type StreamConverter struct { - ID string - Model string - firstWrite bool - contentIndex int - inputTokens int - outputTokens int - thinkingStarted bool - thinkingDone bool - textStarted bool - toolCallsSent map[string]bool + ID string + Model string + firstWrite bool + contentIndex int + inputTokens int + outputTokens int + estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0) + thinkingStarted bool + thinkingDone bool + textStarted bool + toolCallsSent map[string]bool } -func NewStreamConverter(id, model string) *StreamConverter { +func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter { return &StreamConverter{ - ID: id, - Model: model, - firstWrite: true, - toolCallsSent: make(map[string]bool), + ID: id, + Model: model, + firstWrite: true, + estimatedInputTokens: estimatedInputTokens, + toolCallsSent: make(map[string]bool), } } @@ -551,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent { if c.firstWrite { c.firstWrite = false + // Use actual metrics if available, otherwise use estimate c.inputTokens = r.Metrics.PromptEvalCount + if c.inputTokens == 0 && c.estimatedInputTokens > 0 { + c.inputTokens = c.estimatedInputTokens + } events = append(events, StreamEvent{ Event: "message_start", @@ -779,3 +785,121 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments { } return args } + +// CountTokensRequest represents an Anthropic count_tokens request +type CountTokensRequest struct { + Model string `json:"model"` + Messages []MessageParam `json:"messages"` + System any `json:"system,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Thinking *ThinkingConfig `json:"thinking,omitempty"` +} + +// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic) +func EstimateInputTokens(req MessagesRequest) int { + return estimateTokens(CountTokensRequest{ + Model: req.Model, + Messages: req.Messages, + System: req.System, + Tools: req.Tools, + Thinking: req.Thinking, + }) +} + +// CountTokensResponse represents an Anthropic count_tokens response +type CountTokensResponse struct { + InputTokens int `json:"input_tokens"` +} + +// estimateTokens returns a rough estimate of tokens (len/4) +func estimateTokens(req CountTokensRequest) int { + var totalLen int + + // Count system prompt + if req.System != nil { + totalLen += countAnyContent(req.System) + } + + // Count messages + for _, msg := range req.Messages { + // Count role (always present) + totalLen += len(msg.Role) + // Count content + contentLen := countAnyContent(msg.Content) + totalLen += contentLen + } + + for _, tool := range req.Tools { + totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema) + } + + // Return len/4 as rough token estimate, minimum 1 if there's any content + tokens := totalLen / 4 + if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) { + tokens = 1 + } + return tokens +} + +func countAnyContent(content any) int { + if content == nil { + return 0 + } + + switch c := content.(type) { + case string: + return len(c) + case []any: + total := 0 + for _, block := range c { + total += countContentBlock(block) + } + return total + default: + if data, err := json.Marshal(content); err == nil { + return len(data) + } + return 0 + } +} + +func countContentBlock(block any) int { + blockMap, ok := block.(map[string]any) + if !ok { + if s, ok := block.(string); ok { + return len(s) + } + return 0 + } + + total := 0 + blockType, _ := blockMap["type"].(string) + + if text, ok := blockMap["text"].(string); ok { + total += len(text) + } + + if thinking, ok := blockMap["thinking"].(string); ok { + total += len(thinking) + } + + if blockType == "tool_use" { + if data, err := json.Marshal(blockMap); err == nil { + total += len(data) + } + } + + if blockType == "tool_result" { + if data, err := json.Marshal(blockMap); err == nil { + total += len(data) + } + } + + if source, ok := blockMap["source"].(map[string]any); ok { + if data, ok := source["data"].(string); ok { + total += len(data) + } + } + + return total +} diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index 1c2a4a868..a60327e6a 100755 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -605,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) { } func TestStreamConverter_Basic(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) // First chunk resp1 := api.ChatResponse{ @@ -678,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) { } func TestStreamConverter_WithToolCalls(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) resp := api.ChatResponse{ Model: "test-model", @@ -731,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) { func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { // Test that unmarshalable arguments (like channels) are handled gracefully // and don't cause a panic or corrupt stream - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) // Create a channel which cannot be JSON marshaled unmarshalable := make(chan int) @@ -778,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { // Test that valid tool calls still work when mixed with invalid ones - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) unmarshalable := make(chan int) badArgs := api.NewToolCallFunctionArguments() @@ -903,7 +903,7 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) { // events include the required empty fields for SDK compatibility. func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { t.Run("text block start includes empty text", func(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) resp := api.ChatResponse{ Model: "test-model", @@ -937,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { }) t.Run("thinking block start includes empty thinking", func(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) resp := api.ChatResponse{ Model: "test-model", diff --git a/api/client.go b/api/client.go index d70672a6b..eec720b93 100644 --- a/api/client.go +++ b/api/client.go @@ -466,3 +466,15 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { } return &resp, nil } + +// AliasRequest is the request body for creating or updating a model alias. +type AliasRequest struct { + Alias string `json:"alias"` + Target string `json:"target"` + PrefixMatching bool `json:"prefix_matching,omitempty"` +} + +// SetAliasExperimental creates or updates a model alias via the experimental aliases API. +func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error { + return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil) +} diff --git a/cmd/config/claude.go b/cmd/config/claude.go index 80a72f564..c26bd6ccb 100644 --- a/cmd/config/claude.go +++ b/cmd/config/claude.go @@ -1,18 +1,23 @@ package config import ( + "context" "fmt" "os" "os/exec" "path/filepath" "runtime" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) -// Claude implements Runner for Claude Code integration +// Claude implements Runner and AliasConfigurer for Claude Code integration type Claude struct{} +// Compile-time check that Claude implements AliasConfigurer +var _ AliasConfigurer = (*Claude)(nil) + func (c *Claude) String() string { return "Claude Code" } func (c *Claude) args(model string, extra []string) []string { @@ -60,3 +65,96 @@ func (c *Claude) Run(model string, args []string) error { ) return cmd.Run() } + +// ConfigureAliases sets up Primary and Fast model aliases for Claude Code. +func (c *Claude) ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) { + aliases := make(map[string]string) + for k, v := range existing { + aliases[k] = v + } + + if primaryModel != "" { + aliases["primary"] = primaryModel + } + + if !force && aliases["primary"] != "" && aliases["fast"] != "" { + return aliases, false, nil + } + + items, existingModels, cloudModels, client, err := listModels(ctx) + if err != nil { + return nil, false, err + } + + fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n", ansiBold, ansiReset) + fmt.Fprintf(os.Stderr, "%sClaude Code uses multiple models for various tasks%s\n\n", ansiGray, ansiReset) + + fmt.Fprintf(os.Stderr, "%sPrimary%s\n", ansiBold, ansiReset) + fmt.Fprintf(os.Stderr, "%sHandles complex reasoning: planning, code generation, debugging.%s\n\n", ansiGray, ansiReset) + + if aliases["primary"] == "" || force { + primary, err := selectPrompt("Select Primary model:", items) + if err != nil { + return nil, false, err + } + if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil { + return nil, false, err + } + if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil { + return nil, false, err + } + aliases["primary"] = primary + } else { + fmt.Fprintf(os.Stderr, " %s\n\n", aliases["primary"]) + } + + fmt.Fprintf(os.Stderr, "%sFast%s\n", ansiBold, ansiReset) + fmt.Fprintf(os.Stderr, "%sHandles quick operations: file searches, simple edits, status checks.%s\n", ansiGray, ansiReset) + fmt.Fprintf(os.Stderr, "%sSmaller models work well and respond faster.%s\n\n", ansiGray, ansiReset) + + if aliases["fast"] == "" || force { + fast, err := selectPrompt("Select Fast model:", items) + if err != nil { + return nil, false, err + } + if err := pullIfNeeded(ctx, client, existingModels, fast); err != nil { + return nil, false, err + } + if err := ensureAuth(ctx, client, cloudModels, []string{fast}); err != nil { + return nil, false, err + } + aliases["fast"] = fast + } + + return aliases, true, nil +} + +// SetAliases syncs the configured aliases to the Ollama server using prefix matching. +func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + prefixAliases := map[string]string{ + "claude-sonnet-": aliases["primary"], + "claude-haiku-": aliases["fast"], + } + + var errs []string + for prefix, target := range prefixAliases { + req := &api.AliasRequest{ + Alias: prefix, + Target: target, + PrefixMatching: true, + } + if err := client.SetAliasExperimental(ctx, req); err != nil { + errs = append(errs, prefix) + } + } + + if len(errs) > 0 { + return fmt.Errorf("failed to set aliases: %v", errs) + } + return nil +} diff --git a/cmd/config/config.go b/cmd/config/config.go index 5f98bd5ed..1c183e7b9 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -13,7 +13,8 @@ import ( ) type integration struct { - Models []string `json:"models"` + Models []string `json:"models"` + Aliases map[string]string `json:"aliases,omitempty"` } type config struct { @@ -133,8 +134,16 @@ func saveIntegration(appName string, models []string) error { return err } - cfg.Integrations[strings.ToLower(appName)] = &integration{ - Models: models, + key := strings.ToLower(appName) + existing := cfg.Integrations[key] + var aliases map[string]string + if existing != nil && existing.Aliases != nil { + aliases = existing.Aliases + } + + cfg.Integrations[key] = &integration{ + Models: models, + Aliases: aliases, } return save(cfg) @@ -154,6 +163,33 @@ func loadIntegration(appName string) (*integration, error) { return ic, nil } +func saveAliases(appName string, aliases map[string]string) error { + if appName == "" { + return errors.New("app name cannot be empty") + } + + cfg, err := load() + if err != nil { + return err + } + + key := strings.ToLower(appName) + existing := cfg.Integrations[key] + if existing == nil { + existing = &integration{} + } + + if existing.Aliases == nil { + existing.Aliases = make(map[string]string) + } + for k, v := range aliases { + existing.Aliases[k] = v + } + + cfg.Integrations[key] = existing + return save(cfg) +} + func listIntegrations() ([]integration, error) { cfg, err := load() if err != nil { diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go index ae87c6a40..a491a276f 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/config_test.go @@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) { } }) + t.Run("save and load aliases", func(t *testing.T) { + models := []string{"llama3.2"} + if err := saveIntegration("claude", models); err != nil { + t.Fatal(err) + } + aliases := map[string]string{ + "primary": "llama3.2:70b", + "fast": "llama3.2:8b", + } + if err := saveAliases("claude", aliases); err != nil { + t.Fatal(err) + } + + config, err := loadIntegration("claude") + if err != nil { + t.Fatal(err) + } + if config.Aliases == nil { + t.Fatal("expected aliases to be saved") + } + for k, v := range aliases { + if config.Aliases[k] != v { + t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k]) + } + } + }) + + t.Run("saveIntegration preserves aliases", func(t *testing.T) { + if err := saveIntegration("claude", []string{"model-a"}); err != nil { + t.Fatal(err) + } + if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { + t.Fatal(err) + } + + if err := saveIntegration("claude", []string{"model-b"}); err != nil { + t.Fatal(err) + } + config, err := loadIntegration("claude") + if err != nil { + t.Fatal(err) + } + if config.Aliases["primary"] != "model-a" { + t.Errorf("expected aliases to be preserved, got %v", config.Aliases) + } + }) + t.Run("defaultModel returns first model", func(t *testing.T) { saveIntegration("codex", []string{"model-a", "model-b"}) diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go index 714eae625..6991609e3 100644 --- a/cmd/config/integrations.go +++ b/cmd/config/integrations.go @@ -39,6 +39,15 @@ type Editor interface { Models() []string } +// AliasConfigurer can configure model aliases (e.g., for subagent routing). +// Integrations like Claude and Codex use this to route model requests to local models. +type AliasConfigurer interface { + // ConfigureAliases prompts the user to configure aliases and returns the updated map. + ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) + // SetAliases syncs the configured aliases to the server + SetAliases(ctx context.Context, aliases map[string]string) error +} + // integrations is the registry of available integrations. var integrations = map[string]Runner{ "claude": &Claude{}, @@ -129,7 +138,11 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) { return nil, err } } else { - model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items) + prompt := fmt.Sprintf("Select model for %s:", r) + if _, ok := r.(AliasConfigurer); ok { + prompt = fmt.Sprintf("Select Primary model for %s:", r) + } + model, err := selectPrompt(prompt, items) if err != nil { return nil, err } @@ -157,73 +170,146 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) { } } + if err := ensureAuth(ctx, client, cloudModels, selected); err != nil { + return nil, err + } + + return selected, nil +} + +func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error { + if existingModels[model] { + return nil + } + msg := fmt.Sprintf("Download %s?", model) + if ok, err := confirmPrompt(msg); err != nil { + return err + } else if !ok { + return errCancelled + } + fmt.Fprintf(os.Stderr, "\n") + if err := pullModel(ctx, client, model); err != nil { + return fmt.Errorf("failed to pull %s: %w", model, err) + } + return nil +} + +func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) { + client, err := api.ClientFromEnvironment() + if err != nil { + return nil, nil, nil, nil, err + } + + models, err := client.List(ctx) + if err != nil { + return nil, nil, nil, nil, err + } + + var existing []modelInfo + for _, m := range models.Models { + existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""}) + } + + items, _, existingModels, cloudModels := buildModelList(existing, nil, "") + + if len(items) == 0 { + return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull ' first") + } + + return items, existingModels, cloudModels, client, nil +} + +func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error { var selectedCloudModels []string for _, m := range selected { if cloudModels[m] { selectedCloudModels = append(selectedCloudModels, m) } } - if len(selectedCloudModels) > 0 { - // ensure user is signed in - user, err := client.Whoami(ctx) - if err == nil && user != nil && user.Name != "" { - return selected, nil - } + if len(selectedCloudModels) == 0 { + return nil + } - var aErr api.AuthorizationError - if !errors.As(err, &aErr) || aErr.SigninURL == "" { - return nil, err - } + user, err := client.Whoami(ctx) + if err == nil && user != nil && user.Name != "" { + return nil + } - modelList := strings.Join(selectedCloudModels, ", ") - yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) - if err != nil || !yes { - return nil, fmt.Errorf("%s requires sign in", modelList) - } + var aErr api.AuthorizationError + if !errors.As(err, &aErr) || aErr.SigninURL == "" { + return err + } - fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) + modelList := strings.Join(selectedCloudModels, ", ") + yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) + if err != nil || !yes { + return fmt.Errorf("%s requires sign in", modelList) + } - // TODO(parthsareen): extract into auth package for cmd - // Auto-open browser (best effort, fail silently) - switch runtime.GOOS { - case "darwin": - _ = exec.Command("open", aErr.SigninURL).Start() - case "linux": - _ = exec.Command("xdg-open", aErr.SigninURL).Start() - case "windows": - _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() - } + fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) - spinnerFrames := []string{"|", "/", "-", "\\"} - frame := 0 + switch runtime.GOOS { + case "darwin": + _ = exec.Command("open", aErr.SigninURL).Start() + case "linux": + _ = exec.Command("xdg-open", aErr.SigninURL).Start() + case "windows": + _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() + } - fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) + spinnerFrames := []string{"|", "/", "-", "\\"} + frame := 0 - ticker := time.NewTicker(200 * time.Millisecond) - defer ticker.Stop() + fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) - for { - select { - case <-ctx.Done(): - fmt.Fprintf(os.Stderr, "\r\033[K") - return nil, ctx.Err() - case <-ticker.C: - frame++ - fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() - // poll every 10th frame (~2 seconds) - if frame%10 == 0 { - u, err := client.Whoami(ctx) - if err == nil && u != nil && u.Name != "" { - fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) - return selected, nil - } + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\r\033[K") + return ctx.Err() + case <-ticker.C: + frame++ + fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + + // poll every 10th frame (~2 seconds) + if frame%10 == 0 { + u, err := client.Whoami(ctx) + if err == nil && u != nil && u.Name != "" { + fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) + return nil } } } } +} - return selected, nil +func ensureAliases(ctx context.Context, r Runner, name string, primaryModel string, existing map[string]string, force bool) (bool, error) { + ac, ok := r.(AliasConfigurer) + if !ok { + return false, nil + } + + aliases, updated, err := ac.ConfigureAliases(ctx, primaryModel, existing, force) + if err != nil { + return false, err + } + if !updated { + return false, nil + } + + if err := saveAliases(name, aliases); err != nil { + return false, err + } + + if err := ac.SetAliases(ctx, aliases); err != nil { + fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset) + fmt.Fprintf(os.Stderr, "%sAliases saved locally. Server sync will retry on next launch.%s\n\n", ansiGray, ansiReset) + } + + return true, nil } func runIntegration(name, modelName string, args []string) error { @@ -231,6 +317,17 @@ func runIntegration(name, modelName string, args []string) error { if !ok { return fmt.Errorf("unknown integration: %s", name) } + + if _, ok := r.(AliasConfigurer); ok { + if config, err := loadIntegration(name); err == nil && config.Aliases != nil { + primary, fast := config.Aliases["primary"], config.Aliases["fast"] + if primary != "" && fast != "" { + fmt.Fprintf(os.Stderr, "\nLaunching %s with Primary: %s, Fast: %s...\n", r, primary, fast) + return r.Run(modelName, args) + } + } + } + fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName) return r.Run(modelName, args) } @@ -304,10 +401,50 @@ Examples: if !configFlag && modelFlag == "" { if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 { + if _, err := ensureAliases(cmd.Context(), r, name, config.Models[0], config.Aliases, false); errors.Is(err, errCancelled) { + return nil + } else if err != nil { + return err + } return runIntegration(name, config.Models[0], passArgs) } } + if ac, ok := r.(AliasConfigurer); ok { + var existingAliases map[string]string + if existing, err := loadIntegration(name); err == nil { + existingAliases = existing.Aliases + } + aliases, updated, err := ac.ConfigureAliases(cmd.Context(), "", existingAliases, configFlag) + if errors.Is(err, errCancelled) { + return nil + } + if err != nil { + return err + } + if updated { + if err := saveAliases(name, aliases); err != nil { + return err + } + if err := ac.SetAliases(cmd.Context(), aliases); err != nil { + fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset) + } + fmt.Fprintf(os.Stderr, "\n%sConfiguration Complete%s\n", ansiBold, ansiReset) + fmt.Fprintf(os.Stderr, "Primary: %s\n", aliases["primary"]) + fmt.Fprintf(os.Stderr, "Fast: %s\n\n", aliases["fast"]) + } + if err := saveIntegration(name, []string{aliases["primary"]}); err != nil { + return fmt.Errorf("failed to save: %w", err) + } + if configFlag { + if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch { + return runIntegration(name, aliases["primary"], passArgs) + } + return nil + } + return runIntegration(name, aliases["primary"], passArgs) + } + var models []string if modelFlag != "" { models = []string{modelFlag} diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go index dd2056e98..e4b213e64 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/config/integrations_test.go @@ -509,3 +509,19 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) { t.Error("llama3.2 should not be in cloudModels") } } + +func TestAliasConfigurerInterface(t *testing.T) { + t.Run("claude implements AliasConfigurer", func(t *testing.T) { + claude := &Claude{} + if _, ok := interface{}(claude).(AliasConfigurer); !ok { + t.Error("Claude should implement AliasConfigurer") + } + }) + + t.Run("codex does not implement AliasConfigurer", func(t *testing.T) { + codex := &Codex{} + if _, ok := interface{}(codex).(AliasConfigurer); ok { + t.Error("Codex should not implement AliasConfigurer") + } + }) +} diff --git a/cmd/config/selector.go b/cmd/config/selector.go index 956e1f1ea..f617a7c1d 100644 --- a/cmd/config/selector.go +++ b/cmd/config/selector.go @@ -65,6 +65,10 @@ func (s *selectState) handleInput(event inputEvent, char byte) (done bool, resul if len(filtered) > 0 && s.selected < len(filtered) { return true, filtered[s.selected].Name, nil } + // No matches but user typed something - return filter for pull prompt + if len(filtered) == 0 && s.filter != "" { + return true, s.filter, nil + } case eventEscape: return true, "", errCancelled case eventBackspace: @@ -283,7 +287,11 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int { lineCount := 1 if len(filtered) == 0 { - fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset) + if s.filter != "" { + fmt.Fprintf(w, " %s→ Download model: '%s'? Press Enter%s\r\n", ansiGray, s.filter, ansiReset) + } else { + fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset) + } lineCount++ } else { displayCount := min(len(filtered), maxDisplayedItems) diff --git a/cmd/config/selector_test.go b/cmd/config/selector_test.go index 74e8796ee..a6bd64465 100644 --- a/cmd/config/selector_test.go +++ b/cmd/config/selector_test.go @@ -87,10 +87,18 @@ func TestSelectState(t *testing.T) { } }) - t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) { + t.Run("Enter_EmptyFilteredList_ReturnsFilter", func(t *testing.T) { s := newSelectState(items) s.filter = "nonexistent" done, result, err := s.handleInput(eventEnter, 0) + if !done || result != "nonexistent" || err != nil { + t.Errorf("expected (true, 'nonexistent', nil), got (%v, %v, %v)", done, result, err) + } + }) + + t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) { + s := newSelectState([]selectItem{}) + done, result, err := s.handleInput(eventEnter, 0) if done || result != "" || err != nil { t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err) } @@ -568,14 +576,25 @@ func TestRenderSelect(t *testing.T) { } }) - t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) { + t.Run("EmptyFilteredList_ShowsPullPrompt", func(t *testing.T) { s := newSelectState(items) s.filter = "xyz" var buf bytes.Buffer renderSelect(&buf, "Select:", s) + output := buf.String() + if !strings.Contains(output, "Download model: 'xyz'?") { + t.Errorf("expected 'Download model: xyz?' message, got: %s", output) + } + }) + + t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) { + s := newSelectState([]selectItem{}) + var buf bytes.Buffer + renderSelect(&buf, "Select:", s) + if !strings.Contains(buf.String(), "no matches") { - t.Error("expected 'no matches' message") + t.Error("expected 'no matches' message for empty list with no filter") } }) diff --git a/middleware/anthropic.go b/middleware/anthropic.go index ff55b6ebf..5df87a84a 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc { messageID := anthropic.GenerateMessageID() + // Estimate input tokens for streaming (actual count not available until generation completes) + estimatedTokens := anthropic.EstimateInputTokens(req) + w := &AnthropicWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, id: messageID, model: req.Model, - converter: anthropic.NewStreamConverter(messageID, req.Model), + converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens), } if req.Stream { diff --git a/server/aliases.go b/server/aliases.go new file mode 100644 index 000000000..9757a33fe --- /dev/null +++ b/server/aliases.go @@ -0,0 +1,422 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/ollama/ollama/manifest" + "github.com/ollama/ollama/types/model" +) + +const ( + routerConfigFilename = "server.json" + routerConfigVersion = 1 +) + +var errAliasCycle = errors.New("alias cycle detected") + +type aliasEntry struct { + Alias string `json:"alias"` + Target string `json:"target"` + PrefixMatching bool `json:"prefix_matching,omitempty"` +} + +type routerConfig struct { + Version int `json:"version"` + Aliases []aliasEntry `json:"aliases"` +} + +type aliasStore struct { + mu sync.RWMutex + path string + entries map[string]aliasEntry // normalized alias -> entry (exact matches) + prefixEntries []aliasEntry // prefix matches, sorted longest-first +} + +func newAliasStore(path string) (*aliasStore, error) { + store := &aliasStore{ + path: path, + entries: make(map[string]aliasEntry), + } + if err := store.load(); err != nil { + return nil, err + } + return store, nil +} + +func (s *aliasStore) load() error { + data, err := os.ReadFile(s.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + var cfg routerConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return err + } + + if cfg.Version != 0 && cfg.Version != routerConfigVersion { + return fmt.Errorf("unsupported router config version %d", cfg.Version) + } + + for _, entry := range cfg.Aliases { + targetName := model.ParseName(entry.Target) + if !targetName.IsValid() { + slog.Warn("invalid alias target in router config", "target", entry.Target) + continue + } + canonicalTarget := displayAliasName(targetName) + + if entry.PrefixMatching { + // Prefix aliases don't need to be valid model names + alias := strings.TrimSpace(entry.Alias) + if alias == "" { + slog.Warn("empty prefix alias in router config") + continue + } + s.prefixEntries = append(s.prefixEntries, aliasEntry{ + Alias: alias, + Target: canonicalTarget, + PrefixMatching: true, + }) + } else { + aliasName := model.ParseName(entry.Alias) + if !aliasName.IsValid() { + slog.Warn("invalid alias name in router config", "alias", entry.Alias) + continue + } + canonicalAlias := displayAliasName(aliasName) + s.entries[normalizeAliasKey(aliasName)] = aliasEntry{ + Alias: canonicalAlias, + Target: canonicalTarget, + } + } + } + + // Sort prefix entries by alias length descending (longest prefix wins) + s.sortPrefixEntriesLocked() + + return nil +} + +func (s *aliasStore) saveLocked() error { + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + // Combine exact and prefix entries + entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries)) + for _, entry := range s.entries { + entries = append(entries, entry) + } + entries = append(entries, s.prefixEntries...) + + sort.Slice(entries, func(i, j int) bool { + return strings.Compare(entries[i].Alias, entries[j].Alias) < 0 + }) + + cfg := routerConfig{ + Version: routerConfigVersion, + Aliases: entries, + } + + f, err := os.CreateTemp(dir, "router-*.json") + if err != nil { + return err + } + + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(cfg); err != nil { + _ = f.Close() + _ = os.Remove(f.Name()) + return err + } + + if err := f.Close(); err != nil { + _ = os.Remove(f.Name()) + return err + } + + if err := os.Chmod(f.Name(), 0o644); err != nil { + _ = os.Remove(f.Name()) + return err + } + + return os.Rename(f.Name(), s.path) +} + +func (s *aliasStore) ResolveName(name model.Name) (model.Name, bool, error) { + // If a local model exists, do not allow alias shadowing (highest priority). + exists, err := localModelExists(name) + if err != nil { + return name, false, err + } + if exists { + return name, false, nil + } + + key := normalizeAliasKey(name) + + s.mu.RLock() + entry, exactMatch := s.entries[key] + var prefixMatch *aliasEntry + if !exactMatch { + // Try prefix matching - prefixEntries is sorted longest-first + nameStr := strings.ToLower(displayAliasName(name)) + for i := range s.prefixEntries { + prefix := strings.ToLower(s.prefixEntries[i].Alias) + if strings.HasPrefix(nameStr, prefix) { + prefixMatch = &s.prefixEntries[i] + break // First match is longest due to sorting + } + } + } + s.mu.RUnlock() + + if !exactMatch && prefixMatch == nil { + return name, false, nil + } + + var current string + var visited map[string]struct{} + + if exactMatch { + visited = map[string]struct{}{key: {}} + current = entry.Target + } else { + // For prefix match, use the target as-is + visited = map[string]struct{}{} + current = prefixMatch.Target + } + + targetKey := normalizeAliasKeyString(current) + + for { + targetName := model.ParseName(current) + if !targetName.IsValid() { + return name, false, fmt.Errorf("alias target %q is invalid", current) + } + + if _, seen := visited[targetKey]; seen { + return name, false, errAliasCycle + } + visited[targetKey] = struct{}{} + + s.mu.RLock() + next, ok := s.entries[targetKey] + s.mu.RUnlock() + if !ok { + return targetName, true, nil + } + + current = next.Target + targetKey = normalizeAliasKeyString(current) + } +} + +func (s *aliasStore) Set(alias, target model.Name, prefixMatching bool) error { + targetKey := normalizeAliasKey(target) + + s.mu.Lock() + defer s.mu.Unlock() + + if prefixMatching { + // For prefix aliases, we skip cycle detection since prefix matching + // works differently and the target is a specific model + aliasStr := displayAliasName(alias) + + // Remove any existing prefix entry with the same alias + for i, e := range s.prefixEntries { + if strings.EqualFold(e.Alias, aliasStr) { + s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) + break + } + } + + s.prefixEntries = append(s.prefixEntries, aliasEntry{ + Alias: aliasStr, + Target: displayAliasName(target), + PrefixMatching: true, + }) + s.sortPrefixEntriesLocked() + return s.saveLocked() + } + + aliasKey := normalizeAliasKey(alias) + + if aliasKey == targetKey { + return fmt.Errorf("alias cannot point to itself") + } + + visited := map[string]struct{}{aliasKey: {}} + currentKey := targetKey + for { + if _, seen := visited[currentKey]; seen { + return errAliasCycle + } + visited[currentKey] = struct{}{} + + next, ok := s.entries[currentKey] + if !ok { + break + } + currentKey = normalizeAliasKeyString(next.Target) + } + + s.entries[aliasKey] = aliasEntry{ + Alias: displayAliasName(alias), + Target: displayAliasName(target), + } + + return s.saveLocked() +} + +func (s *aliasStore) Delete(alias model.Name) (bool, error) { + aliasKey := normalizeAliasKey(alias) + + s.mu.Lock() + defer s.mu.Unlock() + + // Try exact match first + if _, ok := s.entries[aliasKey]; ok { + delete(s.entries, aliasKey) + return true, s.saveLocked() + } + + // Try prefix entries + aliasStr := displayAliasName(alias) + for i, e := range s.prefixEntries { + if strings.EqualFold(e.Alias, aliasStr) { + s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) + return true, s.saveLocked() + } + } + + return false, nil +} + +// DeleteByString deletes an alias by its raw string value, useful for prefix +// aliases that may not be valid model names. +func (s *aliasStore) DeleteByString(alias string) (bool, error) { + alias = strings.TrimSpace(alias) + aliasLower := strings.ToLower(alias) + + s.mu.Lock() + defer s.mu.Unlock() + + // Try prefix entries first (since this is mainly for prefix aliases) + for i, e := range s.prefixEntries { + if strings.EqualFold(e.Alias, alias) { + s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) + return true, s.saveLocked() + } + } + + // Also check exact entries by normalized key + if _, ok := s.entries[aliasLower]; ok { + delete(s.entries, aliasLower) + return true, s.saveLocked() + } + + return false, nil +} + +func (s *aliasStore) List() []aliasEntry { + s.mu.RLock() + defer s.mu.RUnlock() + + entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries)) + for _, entry := range s.entries { + entries = append(entries, entry) + } + entries = append(entries, s.prefixEntries...) + + sort.Slice(entries, func(i, j int) bool { + return strings.Compare(entries[i].Alias, entries[j].Alias) < 0 + }) + return entries +} + +func normalizeAliasKey(name model.Name) string { + return strings.ToLower(displayAliasName(name)) +} + +func (s *aliasStore) sortPrefixEntriesLocked() { + sort.Slice(s.prefixEntries, func(i, j int) bool { + // Sort by length descending (longest prefix first) + return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias) + }) +} + +func normalizeAliasKeyString(value string) string { + n := model.ParseName(value) + if !n.IsValid() { + return strings.ToLower(strings.TrimSpace(value)) + } + return normalizeAliasKey(n) +} + +func displayAliasName(n model.Name) string { + display := n.DisplayShortest() + if strings.EqualFold(n.Tag, "latest") { + if idx := strings.LastIndex(display, ":"); idx != -1 { + return display[:idx] + } + } + return display +} + +func localModelExists(name model.Name) (bool, error) { + manifests, err := manifest.Manifests(true) + if err != nil { + return false, err + } + needle := name.String() + for existing := range manifests { + if strings.EqualFold(existing.String(), needle) { + return true, nil + } + } + return false, nil +} + +func routerConfigPath() string { + home, err := os.UserHomeDir() + if err != nil { + return filepath.Join(".ollama", routerConfigFilename) + } + return filepath.Join(home, ".ollama", routerConfigFilename) +} + +func (s *Server) aliasStore() (*aliasStore, error) { + s.aliasesOnce.Do(func() { + s.aliases, s.aliasesErr = newAliasStore(routerConfigPath()) + }) + + return s.aliases, s.aliasesErr +} + +func (s *Server) resolveModelAliasName(name model.Name) (model.Name, bool, error) { + store, err := s.aliasStore() + if err != nil { + return name, false, err + } + + if store == nil { + return name, false, nil + } + + return store.ResolveName(name) +} diff --git a/server/routes.go b/server/routes.go index 910b8e954..34c1350a7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,6 +22,7 @@ import ( "os/signal" "slices" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -81,6 +82,9 @@ type Server struct { addr net.Addr sched *Scheduler defaultNumCtx int + aliasesOnce sync.Once + aliases *aliasStore + aliasesErr error } func init() { @@ -191,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + resolvedName, _, err := s.resolveModelAliasName(name) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + name = resolvedName + // We cannot currently consolidate this into GetModel because all we'll // induce infinite recursion given the current code structure. - name, err := getExistingName(name) + name, err = getExistingName(name) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) return @@ -1580,6 +1591,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.POST("/api/copy", s.CopyHandler) + r.GET("/api/experimental/aliases", s.ListAliasesHandler) + r.POST("/api/experimental/aliases", s.CreateAliasHandler) + r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler) // Inference r.GET("/api/ps", s.PsHandler) @@ -1950,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - name, err := getExistingName(name) + resolvedName, _, err := s.resolveModelAliasName(name) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + name = resolvedName + + name, err = getExistingName(name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } - m, err := GetModel(req.Model) + m, err := GetModel(name.String()) if err != nil { switch { case os.IsNotExist(err): diff --git a/server/routes_aliases.go b/server/routes_aliases.go new file mode 100644 index 000000000..d68514e9c --- /dev/null +++ b/server/routes_aliases.go @@ -0,0 +1,159 @@ +package server + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/types/model" +) + +type aliasListResponse struct { + Aliases []aliasEntry `json:"aliases"` +} + +type aliasDeleteRequest struct { + Alias string `json:"alias"` +} + +func (s *Server) ListAliasesHandler(c *gin.Context) { + store, err := s.aliasStore() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var aliases []aliasEntry + if store != nil { + aliases = store.List() + } + + c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases}) +} + +func (s *Server) CreateAliasHandler(c *gin.Context) { + var req aliasEntry + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + } else if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.Alias = strings.TrimSpace(req.Alias) + req.Target = strings.TrimSpace(req.Target) + if req.Alias == "" || req.Target == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"}) + return + } + + // Target must always be a valid model name + targetName := model.ParseName(req.Target) + if !targetName.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)}) + return + } + + var aliasName model.Name + if req.PrefixMatching { + // For prefix aliases, we still parse the alias to normalize it, + // but we allow any non-empty string since prefix patterns may not be valid model names + aliasName = model.ParseName(req.Alias) + // Even if not valid as a model name, we accept it for prefix matching + } else { + aliasName = model.ParseName(req.Alias) + if !aliasName.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)}) + return + } + + if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"}) + return + } + + exists, err := localModelExists(aliasName) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if exists { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)}) + return + } + } + + store, err := s.aliasStore() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil { + status := http.StatusInternalServerError + if errors.Is(err, errAliasCycle) { + status = http.StatusBadRequest + } + c.AbortWithStatusJSON(status, gin.H{"error": err.Error()}) + return + } + + resp := aliasEntry{ + Alias: displayAliasName(aliasName), + Target: displayAliasName(targetName), + PrefixMatching: req.PrefixMatching, + } + if req.PrefixMatching && !aliasName.IsValid() { + // For prefix aliases that aren't valid model names, use the raw alias + resp.Alias = req.Alias + } + c.JSON(http.StatusOK, resp) +} + +func (s *Server) DeleteAliasHandler(c *gin.Context) { + var req aliasDeleteRequest + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + } else if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.Alias = strings.TrimSpace(req.Alias) + if req.Alias == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"}) + return + } + + store, err := s.aliasStore() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + aliasName := model.ParseName(req.Alias) + var deleted bool + if aliasName.IsValid() { + deleted, err = store.Delete(aliasName) + } else { + // For invalid model names (like prefix aliases), try deleting by raw string + deleted, err = store.DeleteByString(req.Alias) + } + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !deleted { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)}) + return + } + + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} diff --git a/server/routes_aliases_test.go b/server/routes_aliases_test.go new file mode 100644 index 000000000..f4cfb4be7 --- /dev/null +++ b/server/routes_aliases_test.go @@ -0,0 +1,426 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" +) + +func TestAliasShadowingRejected(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + s := Server{} + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "shadowed-model", + RemoteHost: "example.com", + From: "test", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"}) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestAliasResolvesForChatRemote(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + var remoteModel string + rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req api.ChatRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatal(err) + } + remoteModel = req.Model + + w.Header().Set("Content-Type", "application/json") + resp := api.ChatResponse{ + Model: req.Model, + Done: true, + DoneReason: "load", + } + if err := json.NewEncoder(w).Encode(&resp); err != nil { + t.Fatal(err) + } + })) + defer rs.Close() + + p, err := url.Parse(rs.URL) + if err != nil { + t.Fatal(err) + } + + t.Setenv("OLLAMA_REMOTES", p.Hostname()) + + s := Server{} + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "target-model", + RemoteHost: rs.URL, + From: "test", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"}) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "alias-model", + Messages: []api.Message{{Role: "user", Content: "hi"}}, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var resp api.ChatResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if resp.Model != "alias-model" { + t.Fatalf("expected response model to be alias-model, got %q", resp.Model) + } + + if remoteModel != "test" { + t.Fatalf("expected remote model to be 'test', got %q", remoteModel) + } +} + +func TestPrefixAliasBasicMatching(t *testing.T) { + tmpDir := t.TempDir() + store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Create a prefix alias: "myprefix-" -> "targetmodel" + targetName := model.ParseName("targetmodel") + + // Set a prefix alias (using "myprefix-" as the pattern) + store.mu.Lock() + store.prefixEntries = append(store.prefixEntries, aliasEntry{ + Alias: "myprefix-", + Target: "targetmodel", + PrefixMatching: true, + }) + store.mu.Unlock() + + // Test that "myprefix-foo" resolves to "targetmodel" + testName := model.ParseName("myprefix-foo") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + if resolved.DisplayShortest() != targetName.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest()) + } + + // Test that "otherprefix-foo" does not resolve + otherName := model.ParseName("otherprefix-foo") + _, wasResolved, err = store.ResolveName(otherName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wasResolved { + t.Fatal("expected name not to be resolved") + } + + // Test that exact alias takes precedence + exactAlias := model.ParseName("myprefix-exact") + exactTarget := model.ParseName("exacttarget") + if err := store.Set(exactAlias, exactTarget, false); err != nil { + t.Fatal(err) + } + + resolved, wasResolved, err = store.ResolveName(exactAlias) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + if resolved.DisplayShortest() != exactTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest()) + } +} + +func TestPrefixAliasLongestMatchWins(t *testing.T) { + tmpDir := t.TempDir() + store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Add two prefix aliases with overlapping patterns + store.mu.Lock() + store.prefixEntries = []aliasEntry{ + {Alias: "abc-", Target: "short-target", PrefixMatching: true}, + {Alias: "abc-def-", Target: "long-target", PrefixMatching: true}, + } + store.sortPrefixEntriesLocked() + store.mu.Unlock() + + // "abc-def-ghi" should match the longer prefix "abc-def-" + testName := model.ParseName("abc-def-ghi") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + expectedLongTarget := model.ParseName("long-target") + if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest()) + } + + // "abc-xyz" should match the shorter prefix "abc-" + testName2 := model.ParseName("abc-xyz") + resolved, wasResolved, err = store.ResolveName(testName2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + expectedShortTarget := model.ParseName("short-target") + if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest()) + } +} + +func TestPrefixAliasChain(t *testing.T) { + tmpDir := t.TempDir() + store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Create a chain: prefix "test-" -> "intermediate" -> "final" + intermediate := model.ParseName("intermediate") + final := model.ParseName("final") + + // Add prefix alias + store.mu.Lock() + store.prefixEntries = []aliasEntry{ + {Alias: "test-", Target: "intermediate", PrefixMatching: true}, + } + store.mu.Unlock() + + // Add exact alias for the intermediate step + if err := store.Set(intermediate, final, false); err != nil { + t.Fatal(err) + } + + // "test-foo" should resolve through the chain to "final" + testName := model.ParseName("test-foo") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + if resolved.DisplayShortest() != final.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest()) + } +} + +func TestPrefixAliasCRUD(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + s := Server{} + + // Create a prefix alias via API + w := createRequest(t, s.CreateAliasHandler, aliasEntry{ + Alias: "myprefix-", + Target: "llama2", + PrefixMatching: true, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var createResp aliasEntry + if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil { + t.Fatal(err) + } + if !createResp.PrefixMatching { + t.Fatal("expected prefix_matching to be true in response") + } + + // List aliases and verify the prefix alias is included + w = createRequest(t, s.ListAliasesHandler, nil) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var listResp aliasListResponse + if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil { + t.Fatal(err) + } + + found := false + for _, a := range listResp.Aliases { + if a.PrefixMatching && a.Target == "llama2" { + found = true + break + } + } + if !found { + t.Fatal("expected to find prefix alias in list") + } + + // Delete the prefix alias + w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"}) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify it's deleted + w = createRequest(t, s.ListAliasesHandler, nil) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil { + t.Fatal(err) + } + + for _, a := range listResp.Aliases { + if a.PrefixMatching { + t.Fatal("expected prefix alias to be deleted") + } + } +} + +func TestPrefixAliasCaseInsensitive(t *testing.T) { + tmpDir := t.TempDir() + store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Add a prefix alias with mixed case + store.mu.Lock() + store.prefixEntries = []aliasEntry{ + {Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true}, + } + store.mu.Unlock() + + // Test that matching is case-insensitive + testName := model.ParseName("myprefix-foo") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved (case-insensitive)") + } + expectedTarget := model.ParseName("targetmodel") + if resolved.DisplayShortest() != expectedTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest()) + } + + // Test uppercase request + testName2 := model.ParseName("MYPREFIX-BAR") + _, wasResolved, err = store.ResolveName(testName2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved (uppercase)") + } +} + +func TestPrefixAliasLocalModelPrecedence(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + s := Server{} + + // Create a local model that would match a prefix alias + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "myprefix-localmodel", + RemoteHost: "example.com", + From: "test", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Create a prefix alias that would match the local model name + w = createRequest(t, s.CreateAliasHandler, aliasEntry{ + Alias: "myprefix-", + Target: "someothermodel", + PrefixMatching: true, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify that resolving "myprefix-localmodel" returns the local model, not the alias target + store, err := s.aliasStore() + if err != nil { + t.Fatal(err) + } + + localModelName := model.ParseName("myprefix-localmodel") + resolved, wasResolved, err := store.ResolveName(localModelName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wasResolved { + t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest()) + } + if resolved.DisplayShortest() != localModelName.DisplayShortest() { + t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest()) + } + + // Also verify that a non-local model matching the prefix DOES resolve to the alias target + nonLocalName := model.ParseName("myprefix-nonexistent") + resolved, wasResolved, err = store.ResolveName(nonLocalName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected non-local model to resolve via prefix alias") + } + expectedTarget := model.ParseName("someothermodel") + if resolved.DisplayShortest() != expectedTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest()) + } +}