mirror of
https://github.com/ollama/ollama.git
synced 2026-02-04 12:42:58 -05:00
Compare commits
2 Commits
v0.15.5-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cefabd79a8 | ||
|
|
df70249520 |
@@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -785,121 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4)
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -605,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -678,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -731,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -778,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -903,7 +903,7 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -937,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
|
||||
@@ -466,15 +466,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
// Claude implements Runner for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
@@ -65,96 +60,3 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up Primary and Fast model aliases for Claude Code.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existing {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if primaryModel != "" {
|
||||
aliases["primary"] = primaryModel
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" && aliases["fast"] != "" {
|
||||
return aliases, false, nil
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sClaude Code uses multiple models for various tasks%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sPrimary%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sHandles complex reasoning: planning, code generation, debugging.%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := selectPrompt("Select Primary model:", items)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", aliases["primary"])
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sFast%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sHandles quick operations: file searches, simple edits, status checks.%s\n", ansiGray, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sSmaller models work well and respond faster.%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
if aliases["fast"] == "" || force {
|
||||
fast, err := selectPrompt("Select Fast model:", items)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, fast); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{fast}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["fast"] = fast
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,8 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
@@ -134,16 +133,8 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
@@ -163,33 +154,6 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
if existing.Aliases == nil {
|
||||
existing.Aliases = make(map[string]string)
|
||||
}
|
||||
for k, v := range aliases {
|
||||
existing.Aliases[k] = v
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
@@ -46,53 +46,6 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
|
||||
@@ -39,15 +39,6 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
|
||||
// Integrations like Claude and Codex use this to route model requests to local models.
|
||||
type AliasConfigurer interface {
|
||||
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
|
||||
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
|
||||
// SetAliases syncs the configured aliases to the server
|
||||
SetAliases(ctx context.Context, aliases map[string]string) error
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
@@ -138,11 +129,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompt := fmt.Sprintf("Select model for %s:", r)
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := selectPrompt(prompt, items)
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -170,146 +157,73 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ensureAliases(ctx context.Context, r Runner, name string, primaryModel string, existing map[string]string, force bool) (bool, error) {
|
||||
ac, ok := r.(AliasConfigurer)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
aliases, updated, err := ac.ConfigureAliases(ctx, primaryModel, existing, force)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !updated {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := saveAliases(name, aliases); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := ac.SetAliases(ctx, aliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sAliases saved locally. Server sync will retry on next launch.%s\n\n", ansiGray, ansiReset)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
@@ -317,17 +231,6 @@ func runIntegration(name, modelName string, args []string) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
if config, err := loadIntegration(name); err == nil && config.Aliases != nil {
|
||||
primary, fast := config.Aliases["primary"], config.Aliases["fast"]
|
||||
if primary != "" && fast != "" {
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with Primary: %s, Fast: %s...\n", r, primary, fast)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
@@ -401,50 +304,10 @@ Examples:
|
||||
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
if _, err := ensureAliases(cmd.Context(), r, name, config.Models[0], config.Aliases, false); errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
return runIntegration(name, config.Models[0], passArgs)
|
||||
}
|
||||
}
|
||||
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
var existingAliases map[string]string
|
||||
if existing, err := loadIntegration(name); err == nil {
|
||||
existingAliases = existing.Aliases
|
||||
}
|
||||
aliases, updated, err := ac.ConfigureAliases(cmd.Context(), "", existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if updated {
|
||||
if err := saveAliases(name, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ac.SetAliases(cmd.Context(), aliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n%sConfiguration Complete%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Primary: %s\n", aliases["primary"])
|
||||
fmt.Fprintf(os.Stderr, "Fast: %s\n\n", aliases["fast"])
|
||||
}
|
||||
if err := saveIntegration(name, []string{aliases["primary"]}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
|
||||
return runIntegration(name, aliases["primary"], passArgs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return runIntegration(name, aliases["primary"], passArgs)
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
models = []string{modelFlag}
|
||||
|
||||
@@ -509,19 +509,3 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -65,10 +65,6 @@ func (s *selectState) handleInput(event inputEvent, char byte) (done bool, resul
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
// No matches but user typed something - return filter for pull prompt
|
||||
if len(filtered) == 0 && s.filter != "" {
|
||||
return true, s.filter, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
@@ -287,11 +283,7 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
if s.filter != "" {
|
||||
fmt.Fprintf(w, " %s→ Download model: '%s'? Press Enter%s\r\n", ansiGray, s.filter, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
}
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
@@ -87,18 +87,10 @@ func TestSelectState(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_ReturnsFilter", func(t *testing.T) {
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "nonexistent" || err != nil {
|
||||
t.Errorf("expected (true, 'nonexistent', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
@@ -576,25 +568,14 @@ func TestRenderSelect(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsPullPrompt", func(t *testing.T) {
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Download model: 'xyz'?") {
|
||||
t.Errorf("expected 'Download model: xyz?' message, got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message for empty list with no filter")
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -131,15 +131,12 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
|
||||
@@ -1,422 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const (
|
||||
routerConfigFilename = "server.json"
|
||||
routerConfigVersion = 1
|
||||
)
|
||||
|
||||
var errAliasCycle = errors.New("alias cycle detected")
|
||||
|
||||
type aliasEntry struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
type routerConfig struct {
|
||||
Version int `json:"version"`
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type aliasStore struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
|
||||
prefixEntries []aliasEntry // prefix matches, sorted longest-first
|
||||
}
|
||||
|
||||
func newAliasStore(path string) (*aliasStore, error) {
|
||||
store := &aliasStore{
|
||||
path: path,
|
||||
entries: make(map[string]aliasEntry),
|
||||
}
|
||||
if err := store.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *aliasStore) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg routerConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Version != 0 && cfg.Version != routerConfigVersion {
|
||||
return fmt.Errorf("unsupported router config version %d", cfg.Version)
|
||||
}
|
||||
|
||||
for _, entry := range cfg.Aliases {
|
||||
targetName := model.ParseName(entry.Target)
|
||||
if !targetName.IsValid() {
|
||||
slog.Warn("invalid alias target in router config", "target", entry.Target)
|
||||
continue
|
||||
}
|
||||
canonicalTarget := displayAliasName(targetName)
|
||||
|
||||
if entry.PrefixMatching {
|
||||
// Prefix aliases don't need to be valid model names
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
slog.Warn("empty prefix alias in router config")
|
||||
continue
|
||||
}
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: alias,
|
||||
Target: canonicalTarget,
|
||||
PrefixMatching: true,
|
||||
})
|
||||
} else {
|
||||
aliasName := model.ParseName(entry.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
|
||||
continue
|
||||
}
|
||||
canonicalAlias := displayAliasName(aliasName)
|
||||
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
|
||||
Alias: canonicalAlias,
|
||||
Target: canonicalTarget,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort prefix entries by alias length descending (longest prefix wins)
|
||||
s.sortPrefixEntriesLocked()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *aliasStore) saveLocked() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine exact and prefix entries
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
|
||||
cfg := routerConfig{
|
||||
Version: routerConfigVersion,
|
||||
Aliases: entries,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(dir, "router-*.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Chmod(f.Name(), 0o644); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Rename(f.Name(), s.path)
|
||||
}
|
||||
|
||||
func (s *aliasStore) ResolveName(name model.Name) (model.Name, bool, error) {
|
||||
// If a local model exists, do not allow alias shadowing (highest priority).
|
||||
exists, err := localModelExists(name)
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
if exists {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
key := normalizeAliasKey(name)
|
||||
|
||||
s.mu.RLock()
|
||||
entry, exactMatch := s.entries[key]
|
||||
var prefixMatch *aliasEntry
|
||||
if !exactMatch {
|
||||
// Try prefix matching - prefixEntries is sorted longest-first
|
||||
nameStr := strings.ToLower(displayAliasName(name))
|
||||
for i := range s.prefixEntries {
|
||||
prefix := strings.ToLower(s.prefixEntries[i].Alias)
|
||||
if strings.HasPrefix(nameStr, prefix) {
|
||||
prefixMatch = &s.prefixEntries[i]
|
||||
break // First match is longest due to sorting
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exactMatch && prefixMatch == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
var current string
|
||||
var visited map[string]struct{}
|
||||
|
||||
if exactMatch {
|
||||
visited = map[string]struct{}{key: {}}
|
||||
current = entry.Target
|
||||
} else {
|
||||
// For prefix match, use the target as-is
|
||||
visited = map[string]struct{}{}
|
||||
current = prefixMatch.Target
|
||||
}
|
||||
|
||||
targetKey := normalizeAliasKeyString(current)
|
||||
|
||||
for {
|
||||
targetName := model.ParseName(current)
|
||||
if !targetName.IsValid() {
|
||||
return name, false, fmt.Errorf("alias target %q is invalid", current)
|
||||
}
|
||||
|
||||
if _, seen := visited[targetKey]; seen {
|
||||
return name, false, errAliasCycle
|
||||
}
|
||||
visited[targetKey] = struct{}{}
|
||||
|
||||
s.mu.RLock()
|
||||
next, ok := s.entries[targetKey]
|
||||
s.mu.RUnlock()
|
||||
if !ok {
|
||||
return targetName, true, nil
|
||||
}
|
||||
|
||||
current = next.Target
|
||||
targetKey = normalizeAliasKeyString(current)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *aliasStore) Set(alias, target model.Name, prefixMatching bool) error {
|
||||
targetKey := normalizeAliasKey(target)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if prefixMatching {
|
||||
// For prefix aliases, we skip cycle detection since prefix matching
|
||||
// works differently and the target is a specific model
|
||||
aliasStr := displayAliasName(alias)
|
||||
|
||||
// Remove any existing prefix entry with the same alias
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: aliasStr,
|
||||
Target: displayAliasName(target),
|
||||
PrefixMatching: true,
|
||||
})
|
||||
s.sortPrefixEntriesLocked()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
if aliasKey == targetKey {
|
||||
return fmt.Errorf("alias cannot point to itself")
|
||||
}
|
||||
|
||||
visited := map[string]struct{}{aliasKey: {}}
|
||||
currentKey := targetKey
|
||||
for {
|
||||
if _, seen := visited[currentKey]; seen {
|
||||
return errAliasCycle
|
||||
}
|
||||
visited[currentKey] = struct{}{}
|
||||
|
||||
next, ok := s.entries[currentKey]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
currentKey = normalizeAliasKeyString(next.Target)
|
||||
}
|
||||
|
||||
s.entries[aliasKey] = aliasEntry{
|
||||
Alias: displayAliasName(alias),
|
||||
Target: displayAliasName(target),
|
||||
}
|
||||
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *aliasStore) Delete(alias model.Name) (bool, error) {
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try exact match first
|
||||
if _, ok := s.entries[aliasKey]; ok {
|
||||
delete(s.entries, aliasKey)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
// Try prefix entries
|
||||
aliasStr := displayAliasName(alias)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// DeleteByString deletes an alias by its raw string value, useful for prefix
|
||||
// aliases that may not be valid model names.
|
||||
func (s *aliasStore) DeleteByString(alias string) (bool, error) {
|
||||
alias = strings.TrimSpace(alias)
|
||||
aliasLower := strings.ToLower(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try prefix entries first (since this is mainly for prefix aliases)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, alias) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// Also check exact entries by normalized key
|
||||
if _, ok := s.entries[aliasLower]; ok {
|
||||
delete(s.entries, aliasLower)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *aliasStore) List() []aliasEntry {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func normalizeAliasKey(name model.Name) string {
|
||||
return strings.ToLower(displayAliasName(name))
|
||||
}
|
||||
|
||||
func (s *aliasStore) sortPrefixEntriesLocked() {
|
||||
sort.Slice(s.prefixEntries, func(i, j int) bool {
|
||||
// Sort by length descending (longest prefix first)
|
||||
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeAliasKeyString(value string) string {
|
||||
n := model.ParseName(value)
|
||||
if !n.IsValid() {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
return normalizeAliasKey(n)
|
||||
}
|
||||
|
||||
func displayAliasName(n model.Name) string {
|
||||
display := n.DisplayShortest()
|
||||
if strings.EqualFold(n.Tag, "latest") {
|
||||
if idx := strings.LastIndex(display, ":"); idx != -1 {
|
||||
return display[:idx]
|
||||
}
|
||||
}
|
||||
return display
|
||||
}
|
||||
|
||||
func localModelExists(name model.Name) (bool, error) {
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
needle := name.String()
|
||||
for existing := range manifests {
|
||||
if strings.EqualFold(existing.String(), needle) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func routerConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(".ollama", routerConfigFilename)
|
||||
}
|
||||
return filepath.Join(home, ".ollama", routerConfigFilename)
|
||||
}
|
||||
|
||||
func (s *Server) aliasStore() (*aliasStore, error) {
|
||||
s.aliasesOnce.Do(func() {
|
||||
s.aliases, s.aliasesErr = newAliasStore(routerConfigPath())
|
||||
})
|
||||
|
||||
return s.aliases, s.aliasesErr
|
||||
}
|
||||
|
||||
func (s *Server) resolveModelAliasName(name model.Name) (model.Name, bool, error) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
|
||||
if store == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
return store.ResolveName(name)
|
||||
}
|
||||
@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
// Clip images are represented as 768 tokens, each an embedding
|
||||
imageNumTokens := 768
|
||||
|
||||
n := len(msgs) - 1
|
||||
// in reverse, find all messages that fit into context window
|
||||
for i := n; i >= 0; i-- {
|
||||
// always include the last message
|
||||
if i == n {
|
||||
continue
|
||||
}
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, m := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(m.Images)
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if truncate && ctxLen > opts.NumCtx {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
} else {
|
||||
n = i
|
||||
}
|
||||
}
|
||||
|
||||
currMsgIdx := n
|
||||
if currMsgIdx > 0 {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||
}
|
||||
|
||||
for cnt, msg := range msgs[currMsgIdx:] {
|
||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptTokenizeCalls(t *testing.T) {
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
model := Model{Template: tmpl}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
limit int
|
||||
msgs []api.Message
|
||||
maxTokenizes int
|
||||
}{
|
||||
{
|
||||
name: "all messages fit",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 1,
|
||||
},
|
||||
{
|
||||
name: "truncate to last message",
|
||||
limit: 5,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizeCount := 0
|
||||
countingTokenize := func(ctx context.Context, s string) ([]int, error) {
|
||||
tokenizeCount++
|
||||
tokens, err := mockRunner{}.Tokenize(ctx, s)
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tokenizeCount > tt.maxTokenizes {
|
||||
t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -82,9 +81,6 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
aliasesOnce sync.Once
|
||||
aliases *aliasStore
|
||||
aliasesErr error
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -195,16 +191,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resolvedName, _, err := s.resolveModelAliasName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
// We cannot currently consolidate this into GetModel because all we'll
|
||||
// induce infinite recursion given the current code structure.
|
||||
name, err = getExistingName(name)
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -1591,9 +1580,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||
r.POST("/api/copy", s.CopyHandler)
|
||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
@@ -1964,20 +1950,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resolvedName, _, err := s.resolveModelAliasName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
name, err = getExistingName(name)
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(name.String())
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type aliasListResponse struct {
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type aliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
func (s *Server) ListAliasesHandler(c *gin.Context) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var aliases []aliasEntry
|
||||
if store != nil {
|
||||
aliases = store.List()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
|
||||
}
|
||||
|
||||
func (s *Server) CreateAliasHandler(c *gin.Context) {
|
||||
var req aliasEntry
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
req.Target = strings.TrimSpace(req.Target)
|
||||
if req.Alias == "" || req.Target == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Target must always be a valid model name
|
||||
targetName := model.ParseName(req.Target)
|
||||
if !targetName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
|
||||
return
|
||||
}
|
||||
|
||||
var aliasName model.Name
|
||||
if req.PrefixMatching {
|
||||
// For prefix aliases, we still parse the alias to normalize it,
|
||||
// but we allow any non-empty string since prefix patterns may not be valid model names
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
// Even if not valid as a model name, we accept it for prefix matching
|
||||
} else {
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := localModelExists(aliasName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, errAliasCycle) {
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp := aliasEntry{
|
||||
Alias: displayAliasName(aliasName),
|
||||
Target: displayAliasName(targetName),
|
||||
PrefixMatching: req.PrefixMatching,
|
||||
}
|
||||
if req.PrefixMatching && !aliasName.IsValid() {
|
||||
// For prefix aliases that aren't valid model names, use the raw alias
|
||||
resp.Alias = req.Alias
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) DeleteAliasHandler(c *gin.Context) {
|
||||
var req aliasDeleteRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
if req.Alias == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
|
||||
return
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
aliasName := model.ParseName(req.Alias)
|
||||
var deleted bool
|
||||
if aliasName.IsValid() {
|
||||
deleted, err = store.Delete(aliasName)
|
||||
} else {
|
||||
// For invalid model names (like prefix aliases), try deleting by raw string
|
||||
deleted, err = store.DeleteByString(req.Alias)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !deleted {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
@@ -1,426 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestAliasShadowingRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "shadowed-model",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasResolvesForChatRemote(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
var remoteModel string
|
||||
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req api.ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remoteModel = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
|
||||
p, err := url.Parse(rs.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "target-model",
|
||||
RemoteHost: rs.URL,
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "alias-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "hi"}},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Model != "alias-model" {
|
||||
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
|
||||
}
|
||||
|
||||
if remoteModel != "test" {
|
||||
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasBasicMatching(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := newAliasStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a prefix alias: "myprefix-" -> "targetmodel"
|
||||
targetName := model.ParseName("targetmodel")
|
||||
|
||||
// Set a prefix alias (using "myprefix-" as the pattern)
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = append(store.prefixEntries, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "targetmodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that "myprefix-foo" resolves to "targetmodel"
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != targetName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test that "otherprefix-foo" does not resolve
|
||||
otherName := model.ParseName("otherprefix-foo")
|
||||
_, wasResolved, err = store.ResolveName(otherName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatal("expected name not to be resolved")
|
||||
}
|
||||
|
||||
// Test that exact alias takes precedence
|
||||
exactAlias := model.ParseName("myprefix-exact")
|
||||
exactTarget := model.ParseName("exacttarget")
|
||||
if err := store.Set(exactAlias, exactTarget, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resolved, wasResolved, err = store.ResolveName(exactAlias)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLongestMatchWins(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := newAliasStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add two prefix aliases with overlapping patterns
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
|
||||
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
|
||||
}
|
||||
store.sortPrefixEntriesLocked()
|
||||
store.mu.Unlock()
|
||||
|
||||
// "abc-def-ghi" should match the longer prefix "abc-def-"
|
||||
testName := model.ParseName("abc-def-ghi")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedLongTarget := model.ParseName("long-target")
|
||||
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// "abc-xyz" should match the shorter prefix "abc-"
|
||||
testName2 := model.ParseName("abc-xyz")
|
||||
resolved, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedShortTarget := model.ParseName("short-target")
|
||||
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := newAliasStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a chain: prefix "test-" -> "intermediate" -> "final"
|
||||
intermediate := model.ParseName("intermediate")
|
||||
final := model.ParseName("final")
|
||||
|
||||
// Add prefix alias
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Add exact alias for the intermediate step
|
||||
if err := store.Set(intermediate, final, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// "test-foo" should resolve through the chain to "final"
|
||||
testName := model.ParseName("test-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != final.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCRUD(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a prefix alias via API
|
||||
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "llama2",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var createResp aliasEntry
|
||||
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !createResp.PrefixMatching {
|
||||
t.Fatal("expected prefix_matching to be true in response")
|
||||
}
|
||||
|
||||
// List aliases and verify the prefix alias is included
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var listResp aliasListResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching && a.Target == "llama2" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected to find prefix alias in list")
|
||||
}
|
||||
|
||||
// Delete the prefix alias
|
||||
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify it's deleted
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching {
|
||||
t.Fatal("expected prefix alias to be deleted")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := newAliasStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a prefix alias with mixed case
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that matching is case-insensitive
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (case-insensitive)")
|
||||
}
|
||||
expectedTarget := model.ParseName("targetmodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test uppercase request
|
||||
testName2 := model.ParseName("MYPREFIX-BAR")
|
||||
_, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (uppercase)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a local model that would match a prefix alias
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "myprefix-localmodel",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Create a prefix alias that would match the local model name
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "someothermodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
localModelName := model.ParseName("myprefix-localmodel")
|
||||
resolved, wasResolved, err := store.ResolveName(localModelName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
|
||||
}
|
||||
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
|
||||
nonLocalName := model.ParseName("myprefix-nonexistent")
|
||||
resolved, wasResolved, err = store.ResolveName(nonLocalName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected non-local model to resolve via prefix alias")
|
||||
}
|
||||
expectedTarget := model.ParseName("someothermodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user