mirror of
https://github.com/ollama/ollama.git
synced 2026-03-06 16:08:21 -05:00
Compare commits
25 Commits
v0.17.2
...
parth-refa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d978006378 | ||
|
|
875089e964 | ||
|
|
1cc7150733 | ||
|
|
5d3e317232 | ||
|
|
ace441429e | ||
|
|
fc2a3825d7 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 | ||
|
|
d98dda4676 |
@@ -1063,7 +1063,7 @@ func DefaultOptions() Options {
|
||||
TopP: 0.9,
|
||||
TypicalP: 1.0,
|
||||
RepeatLastN: 64,
|
||||
RepeatPenalty: 1.1,
|
||||
RepeatPenalty: 1.0,
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Seed: -1,
|
||||
|
||||
324
cmd/cmd.go
324
cmd/cmd.go
@@ -38,9 +38,11 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
@@ -57,36 +59,36 @@ import (
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
return nil, launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return userName, err
|
||||
}
|
||||
|
||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
ok, err := tui.RunConfirm(prompt)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return false, config.ErrCancelled
|
||||
return false, launch.ErrCancelled
|
||||
}
|
||||
return ok, err
|
||||
}
|
||||
@@ -406,12 +408,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||
|
||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
} else if info.RemoteHost != "" || requestedCloud {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
@@ -422,10 +426,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
remoteModel := info.RemoteModel
|
||||
if remoteModel == "" {
|
||||
remoteModel = opts.Model
|
||||
}
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,6 +505,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): consolidate with TUI signin flow
|
||||
func handleCloudAuthorizationError(err error) bool {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
if authErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -604,12 +626,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if requestedCloud {
|
||||
return nil, err
|
||||
}
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -618,6 +644,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return info, err
|
||||
}()
|
||||
if err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -712,7 +741,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
if err := generate(cmd, opts); err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -1892,183 +1927,124 @@ func ensureServerRunning(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||
// and only validate auth/connectivity before interactive chat starts.
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
return fmt.Errorf("error loading model: %w", err)
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
return fmt.Errorf("error running model: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runInteractiveTUI runs the main interactive TUI menu.
|
||||
func runInteractiveTUI(cmd *cobra.Command) {
|
||||
func runInteractiveTUI(cmd *cobra.Command, invocation launch.LauncherInvocation) {
|
||||
// Ensure the server is running before showing the TUI
|
||||
if err := ensureServerRunning(cmd.Context()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
deps := launcherDeps{
|
||||
buildState: launch.BuildLauncherState,
|
||||
runMenu: tui.RunMenu,
|
||||
resolveRunModel: launch.ResolveRunModel,
|
||||
resolveRequestedRunModel: launch.ResolveRequestedRunModel,
|
||||
launchIntegration: launch.LaunchIntegration,
|
||||
runModel: launchInteractiveModel,
|
||||
}
|
||||
|
||||
currentInvocation := invocation
|
||||
for {
|
||||
result, err := tui.Run()
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, currentInvocation, deps)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
return
|
||||
}
|
||||
currentInvocation = launch.LauncherInvocation{}
|
||||
}
|
||||
}
|
||||
|
||||
runModel := func(modelName string) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
_ = config.SetLastModel(modelName)
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
||||
}
|
||||
}
|
||||
type launcherDeps struct {
|
||||
buildState func(context.Context) (*launch.LauncherState, error)
|
||||
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||
resolveRequestedRunModel func(context.Context, string) (string, error)
|
||||
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||
runModel func(*cobra.Command, string) error
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
if err := config.EnsureInstalled(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return true
|
||||
}
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return false // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if err := config.LaunchIntegration(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
func runInteractiveTUIStep(cmd *cobra.Command, invocation launch.LauncherInvocation, deps launcherDeps) (bool, error) {
|
||||
state, err := deps.buildState(cmd.Context())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("build launcher state: %w", err)
|
||||
}
|
||||
|
||||
switch result.Selection {
|
||||
case tui.SelectionNone:
|
||||
// User quit
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
runModel(modelName)
|
||||
}
|
||||
case tui.SelectionChangeRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
modelName := result.Model
|
||||
if modelName == "" {
|
||||
var err error
|
||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if !launchIntegration(result.Integration) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if len(result.Models) > 0 {
|
||||
// Filter out cloud-disabled models
|
||||
var filtered []string
|
||||
for _, m := range result.Models {
|
||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
result.Models = filtered
|
||||
// Multi-select from modal (Editor integrations)
|
||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else if result.Model != "" {
|
||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
||||
continue
|
||||
}
|
||||
// Single-select from modal - save and launch
|
||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
}
|
||||
action, err := deps.runMenu(state)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||
}
|
||||
|
||||
return runLauncherAction(cmd, invocation, action, deps)
|
||||
}
|
||||
|
||||
func saveLauncherSelection(action tui.TUIAction) {
|
||||
// Best effort only: this affects menu recall, not launch correctness.
|
||||
_ = config.SetLastSelection(action.LastSelection())
|
||||
}
|
||||
|
||||
func runLauncherAction(cmd *cobra.Command, invocation launch.LauncherInvocation, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||
switch action.Kind {
|
||||
case tui.TUIActionNone:
|
||||
return false, nil
|
||||
case tui.TUIActionRunModel:
|
||||
saveLauncherSelection(action)
|
||||
var (
|
||||
modelName string
|
||||
err error
|
||||
)
|
||||
if !action.ForceConfigure && invocation.ModelOverride != "" {
|
||||
modelName, err = deps.resolveRequestedRunModel(cmd.Context(), invocation.ModelOverride)
|
||||
} else {
|
||||
modelName, err = deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||
}
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("selecting model: %w", err)
|
||||
}
|
||||
if err := deps.runModel(cmd, modelName); err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
case tui.TUIActionLaunchIntegration:
|
||||
saveLauncherSelection(action)
|
||||
req := action.IntegrationLaunchRequest()
|
||||
if !action.ForceConfigure {
|
||||
req.ModelOverride = invocation.ModelOverride
|
||||
req.ExtraArgs = append([]string(nil), invocation.ExtraArgs...)
|
||||
}
|
||||
err := deps.launchIntegration(cmd.Context(), req)
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2094,7 +2070,7 @@ func NewCLI() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
runInteractiveTUI(cmd)
|
||||
runInteractiveTUI(cmd, launch.LauncherInvocation{})
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2338,7 +2314,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
397
cmd/cmd_launcher_test.go
Normal file
397
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,397 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
)
|
||||
|
||||
func setCmdTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedRequestedRunModelResolution(t *testing.T) func(context.Context, string) (string, error) {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, model string) (string, error) {
|
||||
t.Fatalf("did not expect requested run-model resolution: %s", model)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
t.Fatalf("did not expect integration launch: %+v", req)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||
t.Helper()
|
||||
return func(cmd *cobra.Command, model string) error {
|
||||
t.Fatalf("did not expect chat launch: %s", model)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "enter uses saved model flow",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||
wantModel: "qwen3:8b",
|
||||
},
|
||||
{
|
||||
name: "right forces picker",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||
wantForce: true,
|
||||
wantModel: "glm-5:cloud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.RunModelRequest
|
||||
var launched string
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
gotReq = req
|
||||
return tt.wantModel, nil
|
||||
},
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: func(cmd *cobra.Command, model string) error {
|
||||
launched = model
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, launch.LauncherInvocation{}, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.ForcePicker != tt.wantForce {
|
||||
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||
}
|
||||
if launched != tt.wantModel {
|
||||
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||
}
|
||||
if got := config.LastSelection(); got != "run" {
|
||||
t.Fatalf("expected last selection to be run, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
}{
|
||||
{
|
||||
name: "enter launches integration",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||
},
|
||||
{
|
||||
name: "right forces configure",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||
wantForce: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, launch.LauncherInvocation{}, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.Name != "claude" {
|
||||
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||
}
|
||||
if gotReq.ForceConfigure != tt.wantForce {
|
||||
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||
}
|
||||
if got := config.LastSelection(); got != "claude" {
|
||||
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, launch.LauncherInvocation{}, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
return "", launch.ErrCancelled
|
||||
},
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, launch.LauncherInvocation{}, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return launch.ErrCancelled
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_RunModelUsesInvocationOverrideOnEnter(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
var gotModel string
|
||||
var launched string
|
||||
continueLoop, err := runLauncherAction(cmd, launch.LauncherInvocation{ModelOverride: "qwen3.5:cloud"}, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
resolveRequestedRunModel: func(ctx context.Context, model string) (string, error) {
|
||||
gotModel = model
|
||||
return model, nil
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: func(cmd *cobra.Command, model string) error {
|
||||
launched = model
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected menu loop to continue after launch")
|
||||
}
|
||||
if gotModel != "qwen3.5:cloud" {
|
||||
t.Fatalf("expected requested model override to be used, got %q", gotModel)
|
||||
}
|
||||
if launched != "qwen3.5:cloud" {
|
||||
t.Fatalf("expected launched model to use override, got %q", launched)
|
||||
}
|
||||
if got := config.LastSelection(); got != "run" {
|
||||
t.Fatalf("expected last selection to be run, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_RunModelIgnoresInvocationOverrideOnChange(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
var gotReq launch.RunModelRequest
|
||||
var launched string
|
||||
continueLoop, err := runLauncherAction(cmd, launch.LauncherInvocation{ModelOverride: "qwen3.5:cloud"}, tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true}, launcherDeps{
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
gotReq = req
|
||||
return "llama3.2", nil
|
||||
},
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: func(cmd *cobra.Command, model string) error {
|
||||
launched = model
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected menu loop to continue after launch")
|
||||
}
|
||||
if !gotReq.ForcePicker {
|
||||
t.Fatal("expected change action to force the picker")
|
||||
}
|
||||
if launched != "llama3.2" {
|
||||
t.Fatalf("expected launched model to come from picker flow, got %q", launched)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_IntegrationUsesInvocationOverrideOnEnter(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
continueLoop, err := runLauncherAction(cmd, launch.LauncherInvocation{
|
||||
ModelOverride: "qwen3.5:cloud",
|
||||
ExtraArgs: []string{"--sandbox", "workspace-write"},
|
||||
}, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected menu loop to continue after launch")
|
||||
}
|
||||
if gotReq.Name != "claude" {
|
||||
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||
}
|
||||
if gotReq.ModelOverride != "qwen3.5:cloud" {
|
||||
t.Fatalf("expected model override to be forwarded, got %q", gotReq.ModelOverride)
|
||||
}
|
||||
if gotReq.ForceConfigure {
|
||||
t.Fatal("expected enter action not to force configure")
|
||||
}
|
||||
if len(gotReq.ExtraArgs) != 2 || gotReq.ExtraArgs[0] != "--sandbox" || gotReq.ExtraArgs[1] != "workspace-write" {
|
||||
t.Fatalf("unexpected extra args: %v", gotReq.ExtraArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_IntegrationIgnoresInvocationOverrideOnChange(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
continueLoop, err := runLauncherAction(cmd, launch.LauncherInvocation{
|
||||
ModelOverride: "qwen3.5:cloud",
|
||||
ExtraArgs: []string{"--sandbox", "workspace-write"},
|
||||
}, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
resolveRequestedRunModel: unexpectedRequestedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected menu loop to continue after configure")
|
||||
}
|
||||
if gotReq.ModelOverride != "" {
|
||||
t.Fatalf("expected change action to ignore model override, got %q", gotReq.ModelOverride)
|
||||
}
|
||||
if len(gotReq.ExtraArgs) != 0 {
|
||||
t.Fatalf("expected change action to ignore extra args, got %v", gotReq.ExtraArgs)
|
||||
}
|
||||
if !gotReq.ForceConfigure {
|
||||
t.Fatal("expected change action to force configure")
|
||||
}
|
||||
}
|
||||
227
cmd/cmd_test.go
227
cmd/cmd_test.go
@@ -705,6 +705,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if generateCalled {
|
||||
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelfileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1663,31 +1796,81 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteHost string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
name string
|
||||
model string
|
||||
showStatus int
|
||||
remoteHost string
|
||||
remoteModel string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectWhoami bool
|
||||
expectedError string
|
||||
expectAuthError bool
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
expectWhoami: true,
|
||||
expectedError: "unauthorized",
|
||||
expectAuthError: true,
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://other-remote.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model without local stub returns not found by default",
|
||||
model: "minimax-m2.5:cloud",
|
||||
showStatus: http.StatusNotFound,
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectedError: "not found",
|
||||
expectWhoami: false,
|
||||
expectAuthError: false,
|
||||
},
|
||||
{
|
||||
name: "explicit -cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:latest-cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "dash cloud-like name without explicit source does not require auth",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
@@ -1699,10 +1882,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||
w.WriteHeader(tt.showStatus)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
RemoteModel: tt.remoteModel,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1715,6 +1903,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
case "/api/generate":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
@@ -1727,29 +1917,28 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
Model: tt.model,
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
if whoamiCalled != tt.expectWhoami {
|
||||
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
if tt.expectAuthError {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -3,15 +3,12 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
@@ -20,6 +17,9 @@ type integration struct {
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
}
|
||||
|
||||
// IntegrationConfig is the persisted config for one integration.
|
||||
type IntegrationConfig = integration
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
@@ -155,8 +155,8 @@ func SaveIntegration(appName string, models []string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
||||
func integrationOnboarded(appName string) error {
|
||||
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||
func MarkIntegrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -174,7 +174,7 @@ func integrationOnboarded(appName string) error {
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func IntegrationModel(appName string) string {
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -228,28 +228,8 @@ func SetLastSelection(selection string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
// LoadIntegration returns the saved config for one integration.
|
||||
func LoadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -263,7 +243,8 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return integrationConfig, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
// SaveAliases replaces the saved aliases for one integration.
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
if err := SaveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
if err := SaveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
if err := SaveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
loaded, err := LoadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
loaded, _ = LoadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
if err := SaveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
|
||||
@@ -13,14 +13,6 @@ func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -31,7 +23,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,11 +47,11 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -77,14 +69,14 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,7 +88,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
config, _ := LoadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
@@ -120,7 +112,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
SaveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -137,8 +129,8 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
SaveIntegration("app1", []string{"model-1"})
|
||||
SaveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
config1, _ := LoadIntegration("app1")
|
||||
config2, _ := LoadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
@@ -185,64 +177,6 @@ func TestListIntegrations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -251,7 +185,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
_, err := loadIntegration("test")
|
||||
_, err := LoadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
@@ -265,7 +199,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
config, err := LoadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
@@ -294,7 +228,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
_, err := LoadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,59 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
15
cmd/config/test_helpers_test.go
Normal file
15
cmd/config/test_helpers_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeFakeBinary(t *testing.T, dir, name string) {
|
||||
t.Helper()
|
||||
path := filepath.Join(dir, name)
|
||||
if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake binary: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -75,7 +76,7 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if cfg, err := config.LoadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
@@ -107,15 +108,12 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
if IsCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
return aliases, false, nil
|
||||
}
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
@@ -126,6 +124,9 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
if DefaultSingleSelector == nil {
|
||||
return nil, false, fmt.Errorf("no selector configured")
|
||||
}
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
@@ -133,16 +134,14 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
if err := EnsureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
if IsCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -141,7 +141,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
SaveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
@@ -163,7 +163,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
SaveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
@@ -188,7 +188,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
SaveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
@@ -1,9 +1,7 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -22,24 +20,6 @@ func (c *Cline) Run(model string, args []string) error {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"slices"
|
||||
298
cmd/launch/command_test.go
Normal file
298
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
mockTUI := func(cmd *cobra.Command, inv LauncherInvocation) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
if cmd.Flags().Lookup("model") == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("config") == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
var gotInv LauncherInvocation
|
||||
mockTUI := func(cmd *cobra.Command, inv LauncherInvocation) {
|
||||
tuiCalled = true
|
||||
gotInv = inv
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
if diff := cmp.Diff(LauncherInvocation{}, gotInv); diff != "" {
|
||||
t.Fatalf("launcher invocation mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command, inv LauncherInvocation) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag opens TUI with invocation", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
var gotInv LauncherInvocation
|
||||
mockTUI := func(cmd *cobra.Command, inv LauncherInvocation) {
|
||||
tuiCalled = true
|
||||
gotInv = inv
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when --model flag provided without an integration")
|
||||
}
|
||||
want := LauncherInvocation{ModelOverride: "test-model"}
|
||||
if diff := cmp.Diff(want, gotInv); diff != "" {
|
||||
t.Fatalf("launcher invocation mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command, inv LauncherInvocation) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config flag provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag forwards extra args through TUI invocation", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
var gotInv LauncherInvocation
|
||||
mockTUI := func(cmd *cobra.Command, inv LauncherInvocation) {
|
||||
tuiCalled = true
|
||||
gotInv = inv
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when --model flag is provided without an integration")
|
||||
}
|
||||
want := LauncherInvocation{
|
||||
ModelOverride: "test-model",
|
||||
ExtraArgs: []string{"--sandbox", "workspace-write"},
|
||||
}
|
||||
if diff := cmp.Diff(want, gotInv); diff != "" {
|
||||
t.Fatalf("launcher invocation mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegrationByNameUnknownIntegration(t *testing.T) {
|
||||
err := LaunchIntegrationByName("nonexistent-integration")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegrationByNameNotConfigured(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
err := LaunchIntegrationByName("claude")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when integration is not configured")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no selector configured") {
|
||||
t.Errorf("error should mention missing selector, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndEditIntegrationUnknownIntegration(t *testing.T) {
|
||||
err := SaveAndEditIntegration("nonexistent", []string{"model"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command, inv LauncherInvocation) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubeditor")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
gotCurrent = current
|
||||
return "qwen3:8b", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command, inv LauncherInvocation) {})
|
||||
cmd.SetArgs([]string{"stubapp"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
if gotCurrent != "llama3.2" {
|
||||
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "qwen3:8b" {
|
||||
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,13 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -47,25 +44,6 @@ func (d *Droid) Run(model string, args []string) error {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "droid", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -125,13 +103,12 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if IsCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
98
cmd/launch/files.go
Normal file
98
cmd/launch/files.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if !bytes.Equal(existingContent, data) {
|
||||
backupPath, err = backupToTmp(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -13,14 +14,25 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type stubEditorRunner struct {
|
||||
edited [][]string
|
||||
ranModel string
|
||||
editErr error
|
||||
}
|
||||
|
||||
type stubRunner struct {
|
||||
ranModel string
|
||||
}
|
||||
|
||||
func (s *stubRunner) Run(model string, args []string) error {
|
||||
s.ranModel = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubRunner) String() string { return "StubRunner" }
|
||||
|
||||
func (s *stubEditorRunner) Run(model string, args []string) error {
|
||||
s.ranModel = model
|
||||
return nil
|
||||
@@ -31,6 +43,9 @@ func (s *stubEditorRunner) String() string { return "StubEditor" }
|
||||
func (s *stubEditorRunner) Paths() []string { return nil }
|
||||
|
||||
func (s *stubEditorRunner) Edit(models []string) error {
|
||||
if s.editErr != nil {
|
||||
return s.editErr
|
||||
}
|
||||
cloned := append([]string(nil), models...)
|
||||
s.edited = append(s.edited, cloned)
|
||||
return nil
|
||||
@@ -111,120 +126,8 @@ func TestHasLocalModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
modelFlag := cmd.Flags().Lookup("model")
|
||||
if modelFlag == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
configFlag := cmd.Flags().Lookup("config")
|
||||
if configFlag == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmd_TUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
// Will error because claude isn't configured, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model flag provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config flag provided")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
err := RunIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -261,19 +164,6 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
@@ -418,7 +308,7 @@ func names(items []ModelItem) []string {
|
||||
}
|
||||
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
items, _, _, _ := BuildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
@@ -426,8 +316,14 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||
if strings.HasSuffix(item.Name, ":cloud") {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -438,7 +334,7 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
{Name: "qwen2.5:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
|
||||
@@ -454,7 +350,7 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// All recs pinned at top (cloud before local in mixed case), then non-recs
|
||||
@@ -470,7 +366,7 @@ func TestBuildModelList_PreCheckedFirst(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
items, _, _, _ := BuildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
@@ -484,7 +380,7 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
@@ -492,10 +388,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
|
||||
case "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud":
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -506,7 +406,7 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// glm-4.7-flash and glm-5:cloud are installed so they sort normally;
|
||||
@@ -524,7 +424,7 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
||||
{Name: "kimi-k2.5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
@@ -536,7 +436,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
isCloud := strings.HasSuffix(item.Name, ":cloud")
|
||||
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name)
|
||||
if isInstalled || isCloud {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
@@ -550,7 +456,7 @@ func TestBuildModelList_LatestTagStripped(t *testing.T) {
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, nil, "")
|
||||
items, _, existingModels, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// :latest should be stripped from display names
|
||||
@@ -583,7 +489,7 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
_, _, existingModels, cloudModels := BuildModelList(existing, nil, "")
|
||||
|
||||
if !existingModels["llama3.2"] {
|
||||
t.Error("llama3.2 should be in existingModels")
|
||||
@@ -612,7 +518,7 @@ func TestBuildModelList_RecommendedFieldSet(t *testing.T) {
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
@@ -634,7 +540,7 @@ func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Cloud recs should sort before local recs in mixed case
|
||||
@@ -650,7 +556,7 @@ func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Local recs should sort before cloud recs in only-local case
|
||||
@@ -667,7 +573,7 @@ func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
|
||||
{Name: "custom-model", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// All recommended models should appear before non-recommended installed models
|
||||
@@ -693,7 +599,7 @@ func TestBuildModelList_CheckedBeforeRecs(t *testing.T) {
|
||||
{Name: "glm-5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
items, _, _, _ := BuildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
@@ -711,7 +617,7 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify loadIntegration returns the saved models
|
||||
saved, err := loadIntegration("opencode")
|
||||
saved, err := LoadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -753,7 +659,7 @@ func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) {
|
||||
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("opencode")
|
||||
saved, err := LoadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
@@ -792,7 +698,7 @@ func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *test
|
||||
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("droid")
|
||||
saved, err := LoadIntegration("droid")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
@@ -801,57 +707,27 @@ func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *test
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
if err := SaveIntegration("droid", []string{"existing-model"}); err != nil {
|
||||
t.Fatalf("failed to seed config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &stubEditorRunner{}
|
||||
old, existed := integrations["stubeditor"]
|
||||
integrations["stubeditor"] = stub
|
||||
defer func() {
|
||||
if existed {
|
||||
integrations["stubeditor"] = old
|
||||
} else {
|
||||
delete(integrations, "stubeditor")
|
||||
}
|
||||
}()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
editor := &stubEditorRunner{editErr: errors.New("boom")}
|
||||
err := PrepareEditorIntegration("droid", editor, editor, []string{"new-model"})
|
||||
if err == nil || !strings.Contains(err.Error(), "setup failed") {
|
||||
t.Fatalf("expected setup failure, got %v", err)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("stubeditor")
|
||||
saved, err := LoadIntegration("droid")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
t.Fatalf("failed to reload saved config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
if diff := cmp.Diff([]string{"existing-model"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
@@ -890,6 +766,143 @@ func TestShowOrPull_ModelExists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelExists(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"test-model"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPullWithPolicy(context.Background(), client, "test-model", MissingModelFail)
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPullWithPolicy should return nil when model exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelNotFound_FailDoesNotPromptOrPull(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatal("confirm prompt should not be called with fail policy")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPullWithPolicy(context.Background(), client, "missing-model", MissingModelFail)
|
||||
if err == nil {
|
||||
t.Fatal("expected fail policy to return an error for missing model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ollama pull missing-model") {
|
||||
t.Fatalf("expected actionable pull guidance, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatal("expected pull not to be called with fail policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelNotFound_PromptPolicyPulls(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
if !strings.Contains(prompt, "missing-model") {
|
||||
t.Fatalf("expected prompt to mention missing model, got %q", prompt)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPullWithPolicy(context.Background(), client, "missing-model", MissingModelPromptPull)
|
||||
if err != nil {
|
||||
t.Fatalf("expected prompt policy to pull and succeed, got %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Fatal("expected pull to be called with prompt policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_CloudModelSkipsPullForAllPolicies(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatal("confirm prompt should not be called for explicit cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
for _, policy := range []MissingModelPolicy{MissingModelPromptPull, MissingModelFail} {
|
||||
t.Run(fmt.Sprintf("policy=%d", policy), func(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPullWithPolicy(context.Background(), client, "glm-5:cloud", policy)
|
||||
if err != nil {
|
||||
t.Fatalf("expected cloud model to bypass pull for policy %d, got %v", policy, err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatalf("expected pull not to be called for cloud model with policy %d", policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -1000,8 +1013,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for cloud models
|
||||
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
@@ -1032,8 +1045,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Error("expected pull to be called for cloud model without confirmation")
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model, got %v", err)
|
||||
}
|
||||
|
||||
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/tags":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"models":[]}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"test-user"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
single := func(title string, items []ModelItem, current string) (string, error) {
|
||||
for _, item := range items {
|
||||
if item.Name == "glm-5:cloud" {
|
||||
return item.Name, nil
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
return nil, fmt.Errorf("multi selector should not be called")
|
||||
}
|
||||
|
||||
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
|
||||
if err != nil {
|
||||
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
|
||||
}
|
||||
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
|
||||
t.Fatalf("unexpected selected models: %v", selected)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatal("expected cloud selection to skip pull")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1049,7 +1169,7 @@ func TestConfirmPrompt_DelegatesToHook(t *testing.T) {
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
ok, err := confirmPrompt("test prompt?")
|
||||
ok, err := ConfirmPrompt("test prompt?")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -1071,7 +1191,7 @@ func TestEnsureAuth_NoCloudModels(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
|
||||
err := EnsureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
|
||||
}
|
||||
@@ -1097,7 +1217,7 @@ func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"cloud-model:cloud", "local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
err := EnsureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
|
||||
}
|
||||
@@ -1123,7 +1243,7 @@ func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
err := EnsureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got: %v", err)
|
||||
}
|
||||
@@ -1132,6 +1252,66 @@ func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_PreservesCancelledSignInHook(t *testing.T) {
|
||||
oldSignIn := DefaultSignIn
|
||||
DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
defer func() { DefaultSignIn = oldSignIn }()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprintf(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := EnsureAuth(context.Background(), client, map[string]bool{"cloud-model:cloud": true}, []string{"cloud-model:cloud"})
|
||||
if !errors.Is(err, ErrCancelled) {
|
||||
t.Fatalf("expected ErrCancelled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_DeclinedFallbackReturnsCancelled(t *testing.T) {
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldConfirm }()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprintf(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := EnsureAuth(context.Background(), client, map[string]bool{"cloud-model:cloud": true}, []string{"cloud-model:cloud"})
|
||||
if !errors.Is(err, ErrCancelled) {
|
||||
t.Fatalf("expected ErrCancelled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHyperlink(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1303,7 +1483,7 @@ func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "qwen3:8b", Remote: false},
|
||||
}
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3:8b" {
|
||||
@@ -1320,7 +1500,7 @@ func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("not-installed local rec has VRAM in description", func(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
items, _, _, _ := BuildModelList(nil, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3:8b" {
|
||||
@@ -1337,7 +1517,7 @@ func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "qwen3:8b", Remote: false},
|
||||
}
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
items, _, _, _ := BuildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3:8b" {
|
||||
@@ -1351,30 +1531,6 @@ func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := LaunchIntegration("nonexistent-integration")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_NotConfigured(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Claude is a known integration but not configured in temp dir
|
||||
err := LaunchIntegration("claude")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when integration is not configured")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not configured") {
|
||||
t.Errorf("error should mention 'not configured', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEditorIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1417,13 +1573,3 @@ func TestIntegrationModels(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSaveAndEditIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := SaveAndEditIntegration("nonexistent", []string{"model"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
857
cmd/launch/launch.go
Normal file
857
cmd/launch/launch.go
Normal file
@@ -0,0 +1,857 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||
type LauncherState struct {
|
||||
LastSelection string
|
||||
RunModel string
|
||||
RunModelUsable bool
|
||||
Integrations map[string]LauncherIntegrationState
|
||||
}
|
||||
|
||||
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||
type LauncherIntegrationState struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
Description string
|
||||
Installed bool
|
||||
AutoInstallable bool
|
||||
Selectable bool
|
||||
Changeable bool
|
||||
CurrentModel string
|
||||
ModelUsable bool
|
||||
InstallHint string
|
||||
Editor bool
|
||||
}
|
||||
|
||||
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||
type RunModelRequest struct {
|
||||
ForcePicker bool
|
||||
}
|
||||
|
||||
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||
type IntegrationLaunchRequest struct {
|
||||
Name string
|
||||
ModelOverride string
|
||||
ForceConfigure bool
|
||||
ConfigureOnly bool
|
||||
ExtraArgs []string
|
||||
}
|
||||
|
||||
// LauncherInvocation carries one-shot root launcher overrides derived from CLI flags.
|
||||
type LauncherInvocation struct {
|
||||
ModelOverride string
|
||||
ExtraArgs []string
|
||||
}
|
||||
|
||||
var isInteractiveSession = func() bool {
|
||||
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||
}
|
||||
|
||||
// Runner executes a model with an integration.
|
||||
type Runner interface {
|
||||
Run(model string, args []string) error
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files for integrations that support model configuration.
|
||||
type Editor interface {
|
||||
Paths() []string
|
||||
Edit(models []string) error
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases for integrations like Claude and Codex.
|
||||
type AliasConfigurer interface {
|
||||
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
|
||||
SetAliases(ctx context.Context, aliases map[string]string) error
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
}
|
||||
|
||||
// ModelInfo re-exports launcher model inventory details for callers.
|
||||
type ModelInfo = modelInfo
|
||||
|
||||
// ModelItem represents a model for selection UIs.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// LaunchIntegrationByName launches the named integration using saved config or prompts for setup.
|
||||
func LaunchIntegrationByName(name string) error {
|
||||
return LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: name})
|
||||
}
|
||||
|
||||
// LaunchIntegrationWithModel launches the named integration with the specified model.
|
||||
func LaunchIntegrationWithModel(name, modelName string) error {
|
||||
return LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ModelOverride: modelName,
|
||||
})
|
||||
}
|
||||
|
||||
// SaveAndEditIntegration saves the models for an integration and, when supported,
|
||||
// runs its Edit method to write any integration-managed config files.
|
||||
func SaveAndEditIntegration(name string, models []string) error {
|
||||
key, runner, err := LookupIntegration(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
editor, ok := runner.(Editor)
|
||||
if !ok {
|
||||
return config.SaveIntegration(key, models)
|
||||
}
|
||||
|
||||
return PrepareEditorIntegration(key, runner, editor, models)
|
||||
}
|
||||
|
||||
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
|
||||
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
|
||||
oldSingle := DefaultSingleSelector
|
||||
oldMulti := DefaultMultiSelector
|
||||
if single != nil {
|
||||
DefaultSingleSelector = single
|
||||
}
|
||||
if multi != nil {
|
||||
DefaultMultiSelector = multi
|
||||
}
|
||||
defer func() {
|
||||
DefaultSingleSelector = oldSingle
|
||||
DefaultMultiSelector = oldMulti
|
||||
}()
|
||||
|
||||
return LaunchIntegration(ctx, IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ForceConfigure: true,
|
||||
ConfigureOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
// ConfigureIntegration allows the user to select/change the model for an integration.
|
||||
func ConfigureIntegration(ctx context.Context, name string) error {
|
||||
return LaunchIntegration(ctx, IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ForceConfigure: true,
|
||||
ConfigureOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
// The runTUI callback is called when the root launcher UI should be shown.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command, inv LauncherInvocation)) *cobra.Command {
|
||||
var modelFlag string
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch the Ollama menu or an integration",
|
||||
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||
|
||||
Without arguments, this is equivalent to running 'ollama' directly.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
cline Cline
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 && modelFlag == "" && !configFlag {
|
||||
runTUI(cmd, LauncherInvocation{})
|
||||
return nil
|
||||
}
|
||||
|
||||
var name string
|
||||
var passArgs []string
|
||||
dashIdx := cmd.ArgsLenAtDash()
|
||||
|
||||
if dashIdx == -1 {
|
||||
if len(args) > 1 {
|
||||
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||
}
|
||||
if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
if dashIdx > 1 {
|
||||
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
}
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
passArgs = args[dashIdx:]
|
||||
}
|
||||
|
||||
if name == "" && modelFlag != "" && !configFlag {
|
||||
runTUI(cmd, LauncherInvocation{
|
||||
ModelOverride: modelFlag,
|
||||
ExtraArgs: append([]string(nil), passArgs...),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
var err error
|
||||
name, err = SelectIntegration()
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ModelOverride: modelFlag,
|
||||
ForceConfigure: configFlag || modelFlag == "",
|
||||
ConfigureOnly: configFlag,
|
||||
ExtraArgs: passArgs,
|
||||
})
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
|
||||
type launcherClient struct {
|
||||
apiClient *api.Client
|
||||
modelInventory []ModelInfo
|
||||
cloudDisabled bool
|
||||
cloudStatusLoaded bool
|
||||
inventoryLoaded bool
|
||||
}
|
||||
|
||||
func newLauncherClient() (*launcherClient, error) {
|
||||
apiClient, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &launcherClient{
|
||||
apiClient: apiClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
launchClient, err := newLauncherClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return launchClient.buildLauncherState(ctx)
|
||||
}
|
||||
|
||||
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||
launchClient, err := newLauncherClient()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return launchClient.resolveRunModel(ctx, req)
|
||||
}
|
||||
|
||||
// ResolveRequestedRunModel validates and persists an explicitly requested chat model.
|
||||
func ResolveRequestedRunModel(ctx context.Context, model string) (string, error) {
|
||||
launchClient, err := newLauncherClient()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return launchClient.resolveRequestedRunModel(ctx, model)
|
||||
}
|
||||
|
||||
// LaunchIntegration runs the canonical launcher flow for one integration.
|
||||
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
|
||||
name, runner, err := LookupIntegration(req.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
launchClient, err := newLauncherClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
saved, _ := config.LoadIntegration(name)
|
||||
|
||||
if aliasConfigurer, ok := runner.(AliasConfigurer); ok {
|
||||
return launchClient.launchAliasConfiguredIntegration(ctx, name, runner, aliasConfigurer, saved, req)
|
||||
}
|
||||
if editor, ok := runner.(Editor); ok {
|
||||
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||
}
|
||||
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||
}
|
||||
|
||||
// SelectIntegration lets the user choose which integration to launch.
|
||||
func SelectIntegration() (string, error) {
|
||||
if DefaultSingleSelector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, err := IntegrationSelectionItems()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items, "")
|
||||
}
|
||||
|
||||
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state := &LauncherState{
|
||||
LastSelection: config.LastSelection(),
|
||||
RunModel: config.LastModel(),
|
||||
Integrations: make(map[string]LauncherIntegrationState),
|
||||
}
|
||||
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state.RunModelUsable = runModelUsable
|
||||
|
||||
for _, info := range ListIntegrationInfos() {
|
||||
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state.Integrations[info.Name] = integrationState
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
|
||||
installed := IsIntegrationInstalled(info.Name)
|
||||
autoInstallable := AutoInstallable(info.Name)
|
||||
isEditor := IsEditorIntegration(info.Name)
|
||||
currentModel, usable, err := c.launcherModelState(ctx, info.Name, isEditor)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
|
||||
return LauncherIntegrationState{
|
||||
Name: info.Name,
|
||||
DisplayName: info.DisplayName,
|
||||
Description: info.Description,
|
||||
Installed: installed,
|
||||
AutoInstallable: autoInstallable,
|
||||
Selectable: installed || autoInstallable,
|
||||
Changeable: installed || autoInstallable,
|
||||
CurrentModel: currentModel,
|
||||
ModelUsable: usable,
|
||||
InstallHint: IntegrationInstallHint(info.Name),
|
||||
Editor: isEditor,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
|
||||
cfg, err := config.LoadIntegration(name)
|
||||
if err != nil || len(cfg.Models) == 0 {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
if isEditor {
|
||||
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
|
||||
if len(filtered) > 0 {
|
||||
return filtered[0], true, nil
|
||||
}
|
||||
return cfg.Models[0], false, nil
|
||||
}
|
||||
|
||||
model := cfg.Models[0]
|
||||
usable, err := c.savedModelUsable(ctx, model)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return model, usable, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||
current := config.LastModel()
|
||||
if !req.ForcePicker {
|
||||
usable, err := c.savedModelUsable(ctx, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if usable {
|
||||
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(current); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
}
|
||||
|
||||
model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(model); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveRequestedRunModel(ctx context.Context, model string) (string, error) {
|
||||
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(model); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
current := primaryModelFromConfig(saved)
|
||||
target := req.ModelOverride
|
||||
needsConfigure := req.ForceConfigure
|
||||
|
||||
if target == "" {
|
||||
target = current
|
||||
usable, err := c.savedModelUsable(ctx, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !usable {
|
||||
needsConfigure = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = selected
|
||||
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if target == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, target, req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
models = selected
|
||||
} else if err := c.ensureModelsReady(ctx, models); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if needsConfigure || req.ModelOverride != "" {
|
||||
if err := PrepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, models[0], req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchAliasConfiguredIntegration(ctx context.Context, name string, runner Runner, aliases AliasConfigurer, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
primary := req.ModelOverride
|
||||
var existingAliases map[string]string
|
||||
if saved != nil {
|
||||
existingAliases = saved.Aliases
|
||||
if primary == "" {
|
||||
primary = primaryModelFromConfig(saved)
|
||||
}
|
||||
}
|
||||
|
||||
forceConfigure := req.ForceConfigure
|
||||
if primary == "" {
|
||||
forceConfigure = true
|
||||
} else if req.ModelOverride == "" {
|
||||
// Only auto-force reconfiguration for saved/default models.
|
||||
// Explicit --model overrides should be respected and validated via ensureModelsReady.
|
||||
usable, err := c.savedModelUsable(ctx, primary)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !usable {
|
||||
forceConfigure = true
|
||||
}
|
||||
}
|
||||
|
||||
resolvedAliases := cloneAliases(existingAliases)
|
||||
if forceConfigure || primary != "" {
|
||||
var changed bool
|
||||
var err error
|
||||
resolvedAliases, changed, err = aliases.ConfigureAliases(ctx, primary, existingAliases, forceConfigure)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if changed || primary == "" {
|
||||
primary = resolvedAliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if primary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.ensureModelsReady(ctx, []string{primary}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := syncAliases(ctx, c.apiClient, aliases, name, primary, resolvedAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: could not sync aliases: %v\n", err)
|
||||
}
|
||||
if err := config.SaveAliases(name, normalizedAliases(primary, resolvedAliases)); err != nil {
|
||||
return fmt.Errorf("failed to save aliases: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, []string{primary}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, primary, req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||
if selector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, _, err := c.loadSelectableModels(ctx, singleModelPrechecked(current), current, "no models available, run 'ollama pull <model>' first")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
selected, err := selector(title, items, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
|
||||
if DefaultMultiSelector == nil {
|
||||
return nil, fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, firstModel(preChecked), "no models available")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.ensureModelsReady(ctx, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
|
||||
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, orderedChecked, _, _ := BuildModelList(c.modelInventory, preChecked, current)
|
||||
if c.cloudDisabled {
|
||||
items = FilterCloudItems(items)
|
||||
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil, nil, errors.New(emptyMessage)
|
||||
}
|
||||
return items, orderedChecked, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||
var deduped []string
|
||||
seen := make(map[string]bool, len(models))
|
||||
for _, model := range models {
|
||||
if model == "" || seen[model] {
|
||||
continue
|
||||
}
|
||||
seen[model] = true
|
||||
deduped = append(deduped, model)
|
||||
}
|
||||
models = deduped
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
missingModelPolicy := MissingModelPromptPull
|
||||
if !isInteractiveSession() {
|
||||
missingModelPolicy = MissingModelFail
|
||||
}
|
||||
|
||||
cloudModels := make(map[string]bool, len(models))
|
||||
for _, model := range models {
|
||||
if err := ShowOrPullWithPolicy(ctx, c.apiClient, model, missingModelPolicy); err != nil {
|
||||
return err
|
||||
}
|
||||
if IsCloudModelName(model) {
|
||||
cloudModels[model] = true
|
||||
}
|
||||
}
|
||||
|
||||
return EnsureAuth(ctx, c.apiClient, cloudModels, models)
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
|
||||
if req.ForceConfigure {
|
||||
return editorPreCheckedModels(saved, req.ModelOverride), true
|
||||
}
|
||||
|
||||
if req.ModelOverride != "" {
|
||||
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
|
||||
models = c.filterDisabledCloudModels(ctx, models)
|
||||
return models, len(models) == 0
|
||||
}
|
||||
|
||||
if saved == nil || len(saved.Models) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
models := c.filterDisabledCloudModels(ctx, saved.Models)
|
||||
return models, len(models) == 0
|
||||
}
|
||||
|
||||
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
|
||||
c.ensureCloudStatus(ctx)
|
||||
if !c.cloudDisabled {
|
||||
return append([]string(nil), models...)
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
if !IsCloudModelName(model) {
|
||||
filtered = append(filtered, model)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return c.singleModelUsable(name), nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) singleModelUsable(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if IsCloudModelName(name) {
|
||||
return !c.cloudDisabled
|
||||
}
|
||||
return c.hasLocalModel(name)
|
||||
}
|
||||
|
||||
func (c *launcherClient) hasLocalModel(name string) bool {
|
||||
for _, model := range c.modelInventory {
|
||||
if model.Remote {
|
||||
continue
|
||||
}
|
||||
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureCloudStatus(ctx context.Context) {
|
||||
if c.cloudStatusLoaded {
|
||||
return
|
||||
}
|
||||
c.cloudDisabled, _ = CloudStatusDisabled(ctx, c.apiClient)
|
||||
c.cloudStatusLoaded = true
|
||||
}
|
||||
|
||||
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
|
||||
if c.inventoryLoaded {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := c.apiClient.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.ensureCloudStatus(ctx)
|
||||
c.modelInventory = c.modelInventory[:0]
|
||||
for _, model := range resp.Models {
|
||||
c.modelInventory = append(c.modelInventory, ModelInfo{
|
||||
Name: model.Name,
|
||||
Remote: model.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
if c.cloudDisabled {
|
||||
c.modelInventory = FilterCloudModels(c.modelInventory)
|
||||
}
|
||||
c.inventoryLoaded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func runIntegration(runner Runner, modelName string, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
|
||||
return runner.Run(modelName, args)
|
||||
}
|
||||
|
||||
func syncAliases(ctx context.Context, client *api.Client, aliasConfigurer AliasConfigurer, name, model string, existing map[string]string) error {
|
||||
aliases := cloneAliases(existing)
|
||||
aliases["primary"] = model
|
||||
|
||||
if IsCloudModelName(model) {
|
||||
aliases["fast"] = model
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if err := aliasConfigurer.SetAliases(ctx, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
return config.SaveAliases(name, aliases)
|
||||
}
|
||||
|
||||
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
|
||||
if req.ConfigureOnly {
|
||||
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !launch {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
return err
|
||||
}
|
||||
return runIntegration(runner, model, req.ExtraArgs)
|
||||
}
|
||||
|
||||
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||
if cfg == nil || len(cfg.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return cfg.Models[0]
|
||||
}
|
||||
|
||||
func cloneAliases(aliases map[string]string) map[string]string {
|
||||
if len(aliases) == 0 {
|
||||
return make(map[string]string)
|
||||
}
|
||||
|
||||
cloned := make(map[string]string, len(aliases))
|
||||
for key, value := range aliases {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func normalizedAliases(primary string, aliases map[string]string) map[string]string {
|
||||
normalized := cloneAliases(aliases)
|
||||
normalized["primary"] = primary
|
||||
if IsCloudModelName(primary) {
|
||||
normalized["fast"] = primary
|
||||
} else {
|
||||
delete(normalized, "fast")
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func singleModelPrechecked(current string) []string {
|
||||
if current == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{current}
|
||||
}
|
||||
|
||||
func firstModel(models []string) string {
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return models[0]
|
||||
}
|
||||
|
||||
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||
if override == "" {
|
||||
if saved == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]string(nil), saved.Models...)
|
||||
}
|
||||
return append([]string{override}, additionalSavedModels(saved, override)...)
|
||||
}
|
||||
|
||||
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
|
||||
if saved == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range saved.Models {
|
||||
if model != exclude {
|
||||
models = append(models, model)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
1110
cmd/launch/launch_test.go
Normal file
1110
cmd/launch/launch_test.go
Normal file
File diff suppressed because it is too large
Load Diff
726
cmd/launch/models.go
Normal file
726
cmd/launch/models.go
Normal file
@@ -0,0 +1,726 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
)
|
||||
|
||||
var recommendedModels = []ModelItem{
|
||||
{Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
|
||||
{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
|
||||
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
|
||||
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
|
||||
{Name: "qwen3:8b", Description: "Efficient all-purpose assistant", Recommended: true},
|
||||
}
|
||||
|
||||
var recommendedVRAM = map[string]string{
|
||||
"glm-4.7-flash": "~25GB",
|
||||
"qwen3:8b": "~11GB",
|
||||
}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"minimax-m2.5": {Context: 204_800, Output: 128_000},
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"glm-5": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips explicit cloud suffixes.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
// MissingModelPolicy controls how model-not-found errors should be handled.
|
||||
type MissingModelPolicy int
|
||||
|
||||
const (
|
||||
// MissingModelPromptPull prompts the user to download missing local models.
|
||||
MissingModelPromptPull MissingModelPolicy = iota
|
||||
// MissingModelFail returns an error for missing local models without prompting.
|
||||
MissingModelFail
|
||||
)
|
||||
|
||||
// SelectModelWithSelector prompts the user to select a model using the provided selector.
|
||||
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
|
||||
if selector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
current := config.LastModel()
|
||||
selected, err := selector("Select model to run:", items, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, selected); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := EnsureAuth(ctx, client, cloudModels, []string{selected}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// SelectModel lets the user select a model to run.
|
||||
func SelectModel(ctx context.Context) (string, error) {
|
||||
return SelectModelWithSelector(ctx, DefaultSingleSelector)
|
||||
}
|
||||
|
||||
// OpenBrowser opens the URL in the user's browser.
|
||||
func OpenBrowser(url string) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureAuth ensures the user is signed in before cloud-backed models run.
|
||||
func EnsureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
if disabled, known := CloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
|
||||
if DefaultSignIn != nil {
|
||||
_, err := DefaultSignIn(modelList, aErr.SigninURL)
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return ErrCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return ErrCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !yes {
|
||||
return ErrCancelled
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
OpenBrowser(aErr.SigninURL)
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||
func SelectIntegrationItems() ([]ModelItem, error) {
|
||||
return IntegrationSelectionItems()
|
||||
}
|
||||
|
||||
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
|
||||
key, runner, err := LookupIntegration(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := CloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = FilterCloudModels(existing)
|
||||
}
|
||||
|
||||
var preChecked []string
|
||||
if saved, err := config.LoadIntegration(key); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := runner.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
|
||||
items, preChecked, existingModels, cloudModels := BuildModelList(existing, preChecked, current)
|
||||
if cloudDisabled {
|
||||
items = FilterCloudItems(items)
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no models available")
|
||||
}
|
||||
|
||||
var selected []string
|
||||
if _, ok := runner.(Editor); ok {
|
||||
selected, err = multi(fmt.Sprintf("Select models for %s:", runner), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompt := fmt.Sprintf("Select model for %s:", runner)
|
||||
if _, ok := runner.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", runner)
|
||||
}
|
||||
model, err := single(prompt, items, current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] && !IsCloudModelName(m) {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
if len(toPull) > 0 {
|
||||
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
|
||||
if ok, err := ConfirmPrompt(msg); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errCancelled
|
||||
}
|
||||
for _, m := range toPull {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, m, false); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := EnsureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return selectModelsWithSelectors(ctx, name, current, DefaultSingleSelector, DefaultMultiSelector)
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if IsCloudModelName(model) || existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
|
||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
return ShowOrPullWithPolicy(ctx, client, model, MissingModelPromptPull)
|
||||
}
|
||||
|
||||
// ShowOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||
func ShowOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy MissingModelPolicy) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if IsCloudModelName(model) {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case MissingModelFail:
|
||||
return fmt.Errorf("model %q not found; run 'ollama pull %s' first", model, model)
|
||||
default:
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
}
|
||||
|
||||
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model, false); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := CloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = FilterCloudModels(existing)
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := BuildModelList(existing, nil, "")
|
||||
if cloudDisabled {
|
||||
items = FilterCloudItems(items)
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func resolveEditorModels(name string, models []string, picker func() ([]string, error)) ([]string, error) {
|
||||
filtered := filterDisabledCloudModels(models)
|
||||
if len(filtered) != len(models) {
|
||||
if err := config.SaveIntegration(name, filtered); err != nil {
|
||||
return nil, fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
selected, err := picker()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := config.SaveIntegration(name, selected); err != nil {
|
||||
return nil, fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// PrepareEditorIntegration persists models and applies editor-managed config files.
|
||||
func PrepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunIntegration executes a configured integration with the selected model.
|
||||
func RunIntegration(name, modelName string, args []string) error {
|
||||
_, runner, err := LookupIntegration(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
|
||||
return runner.Run(modelName, args)
|
||||
}
|
||||
|
||||
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||
paths := editor.Paths()
|
||||
if len(paths) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
|
||||
for _, path := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", path)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
return ConfirmPrompt("Proceed?")
|
||||
}
|
||||
|
||||
// BuildModelList merges existing models with recommendations for selection UIs.
|
||||
func BuildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
recDesc := make(map[string]string)
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
recDesc[rec.Name] = rec.Description
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if IsCloudModelName(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
var parts []string
|
||||
if items[i].Description != "" {
|
||||
parts = append(parts, items[i].Description)
|
||||
}
|
||||
if vram := recommendedVRAM[items[i].Name]; vram != "" {
|
||||
parts = append(parts, vram)
|
||||
}
|
||||
parts = append(parts, "(not downloaded)")
|
||||
items[i].Description = strings.Join(parts, ", ")
|
||||
}
|
||||
}
|
||||
|
||||
recRank := make(map[string]int)
|
||||
for i, rec := range recommendedModels {
|
||||
recRank[rec.Name] = i + 1
|
||||
}
|
||||
|
||||
onlyLocal := hasLocalModel && !hasCloudModel
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
|
||||
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
|
||||
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if aRec != bRec {
|
||||
if aRec {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if aRec && bRec {
|
||||
if aCloud != bCloud {
|
||||
if onlyLocal {
|
||||
if aCloud {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
if aCloud {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return recRank[a.Name] - recRank[b.Name]
|
||||
}
|
||||
if aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
// IsCloudModelDisabled reports whether the given model name looks like a cloud model and cloud features are disabled.
|
||||
func IsCloudModelDisabled(ctx context.Context, name string) bool {
|
||||
if !IsCloudModelName(name) {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
disabled, _ := CloudStatusDisabled(ctx, client)
|
||||
return disabled
|
||||
}
|
||||
|
||||
// IsCloudModelName reports whether the model name has an explicit cloud source.
|
||||
func IsCloudModelName(name string) bool {
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// FilterCloudModels drops remote-only models from the given inventory.
|
||||
func FilterCloudModels(existing []modelInfo) []modelInfo {
|
||||
filtered := existing[:0]
|
||||
for _, m := range existing {
|
||||
if !m.Remote {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterDisabledCloudModels(models []string) []string {
|
||||
var filtered []string
|
||||
for _, m := range models {
|
||||
if !IsCloudModelDisabled(context.Background(), m) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// FilterCloudItems removes cloud models from selection items.
|
||||
func FilterCloudItems(items []ModelItem) []ModelItem {
|
||||
filtered := items[:0]
|
||||
for _, item := range items {
|
||||
if !IsCloudModelName(item.Name) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
// GetModelItems returns a list of model items including recommendations for the TUI.
|
||||
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := CloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = FilterCloudModels(existing)
|
||||
}
|
||||
|
||||
lastModel := config.LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
preChecked = []string{lastModel}
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := BuildModelList(existing, preChecked, lastModel)
|
||||
if cloudDisabled {
|
||||
items = FilterCloudItems(items)
|
||||
}
|
||||
|
||||
return items, existingModels
|
||||
}
|
||||
|
||||
// CloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||
func CloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, false
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model, Insecure: insecure}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -35,7 +36,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
|
||||
firstLaunch := true
|
||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
||||
if integrationConfig, err := config.LoadIntegration("openclaw"); err == nil {
|
||||
firstLaunch = !integrationConfig.Onboarded
|
||||
}
|
||||
|
||||
@@ -45,7 +46,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
||||
ok, err := ConfirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -107,7 +108,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
if err := config.MarkIntegrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
if err := config.MarkIntegrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -426,7 +427,7 @@ func ensureOpenclawInstalled() (string, error) {
|
||||
"and select OpenClaw")
|
||||
}
|
||||
|
||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -116,9 +116,9 @@ func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration() error = %v", err)
|
||||
t.Fatalf("LoadIntegration() error = %v", err)
|
||||
}
|
||||
if !integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to be persisted after successful run")
|
||||
@@ -147,7 +147,7 @@ func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
|
||||
t.Fatal("expected run failure")
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to remain unset after failed run")
|
||||
}
|
||||
@@ -1528,7 +1528,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Error("expected false for fresh config")
|
||||
}
|
||||
@@ -1542,7 +1542,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
@@ -1556,7 +1556,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
if err := integrationOnboarded("OpenClaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true when set with different case")
|
||||
}
|
||||
@@ -1575,7 +1575,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify onboarded is set
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
@@ -12,34 +10,12 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
@@ -47,25 +23,6 @@ func (o *OpenCode) Run(model string, args []string) error {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "opencode", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -122,13 +79,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
@@ -147,8 +109,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -158,7 +118,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if IsCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
@@ -172,7 +132,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if IsCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -232,6 +232,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "Ollama" {
|
||||
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "My Custom Ollama" {
|
||||
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
@@ -628,6 +666,8 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"glm-5", true, 202_752, 131_072},
|
||||
{"glm-5:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -26,15 +26,6 @@ func (p *Pi) Run(model string, args []string) error {
|
||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := p.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("pi", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -206,8 +197,15 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
||||
"_launch": true,
|
||||
}
|
||||
|
||||
applyCloudContextFallback := func() {
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
applyCloudContextFallback()
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -224,14 +222,19 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo
|
||||
hasContextWindow := false
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
cfg["contextWindow"] = int(ctxLen)
|
||||
hasContextWindow = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasContextWindow {
|
||||
applyCloudContextFallback()
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -798,6 +798,43 @@ func TestCreateConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 262_144 {
|
||||
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 202_752 {
|
||||
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
385
cmd/launch/registry.go
Normal file
385
cmd/launch/registry.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
// IntegrationInstallSpec describes how launcher should detect and guide installation.
|
||||
type IntegrationInstallSpec struct {
|
||||
CheckInstalled func() bool
|
||||
EnsureInstalled func() error
|
||||
URL string
|
||||
Command []string
|
||||
}
|
||||
|
||||
// IntegrationSpec is the canonical registry entry for one integration.
|
||||
type IntegrationSpec struct {
|
||||
Name string
|
||||
Runner Runner
|
||||
Aliases []string
|
||||
Hidden bool
|
||||
Description string
|
||||
Install IntegrationInstallSpec
|
||||
}
|
||||
|
||||
// IntegrationInfo contains display information about a registered integration.
|
||||
type IntegrationInfo struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
Description string
|
||||
}
|
||||
|
||||
var launcherIntegrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||
|
||||
var integrationSpecs = []*IntegrationSpec{
|
||||
{
|
||||
Name: "claude",
|
||||
Runner: &Claude{},
|
||||
Description: "Anthropic's coding tool with subagents",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := (&Claude{}).findPath()
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://code.claude.com/docs/en/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cline",
|
||||
Runner: &Cline{},
|
||||
Description: "Autonomous coding agent with parallel execution",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("cline")
|
||||
return err == nil
|
||||
},
|
||||
Command: []string{"npm", "install", "-g", "cline"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "codex",
|
||||
Runner: &Codex{},
|
||||
Description: "OpenAI's open-source coding agent",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("codex")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://developers.openai.com/codex/cli/",
|
||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "droid",
|
||||
Runner: &Droid{},
|
||||
Description: "Factory's coding agent across terminal and IDEs",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "opencode",
|
||||
Runner: &OpenCode{},
|
||||
Description: "Anomaly's open-source coding agent",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://opencode.ai",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "openclaw",
|
||||
Runner: &Openclaw{},
|
||||
Aliases: []string{"clawdbot", "moltbot"},
|
||||
Description: "Personal AI with 100+ skills",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
_, err := ensureOpenclawInstalled()
|
||||
return err
|
||||
},
|
||||
URL: "https://docs.openclaw.ai",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "pi",
|
||||
Runner: &Pi{},
|
||||
Description: "Minimal AI agent toolkit with plugin support",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("pi")
|
||||
return err == nil
|
||||
},
|
||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var integrationSpecsByName map[string]*IntegrationSpec
|
||||
|
||||
func init() {
|
||||
rebuildIntegrationSpecIndexes()
|
||||
}
|
||||
|
||||
func hyperlink(url, text string) string {
|
||||
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||
}
|
||||
|
||||
func rebuildIntegrationSpecIndexes() {
|
||||
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||
|
||||
canonical := make(map[string]bool, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
key := strings.ToLower(spec.Name)
|
||||
if key == "" {
|
||||
panic("launch: integration spec missing name")
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||
}
|
||||
canonical[key] = true
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
|
||||
seenAliases := make(map[string]string)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
key := strings.ToLower(alias)
|
||||
if key == "" {
|
||||
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||
}
|
||||
if owner, exists := seenAliases[key]; exists {
|
||||
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||
}
|
||||
seenAliases[key] = spec.Name
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
}
|
||||
|
||||
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||
for _, name := range launcherIntegrationOrder {
|
||||
key := strings.ToLower(name)
|
||||
if orderSeen[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||
}
|
||||
orderSeen[key] = true
|
||||
|
||||
spec, ok := integrationSpecsByName[key]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||
}
|
||||
if spec.Name != key {
|
||||
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||
}
|
||||
if spec.Hidden {
|
||||
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// OverrideIntegration replaces one registry entry's runner and returns a restore function.
|
||||
func OverrideIntegration(name string, runner Runner) func() {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
key := strings.ToLower(name)
|
||||
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||
return func() {
|
||||
delete(integrationSpecsByName, key)
|
||||
}
|
||||
}
|
||||
|
||||
original := spec.Runner
|
||||
spec.Runner = runner
|
||||
return func() {
|
||||
spec.Runner = original
|
||||
}
|
||||
}
|
||||
|
||||
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||
func LookupIntegration(name string) (string, Runner, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return spec.Name, spec.Runner, nil
|
||||
}
|
||||
|
||||
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
if spec.Hidden {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, *spec)
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||
for i, name := range launcherIntegrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
|
||||
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return visible
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
infos := make([]IntegrationInfo, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
infos = append(infos, IntegrationInfo{
|
||||
Name: spec.Name,
|
||||
DisplayName: spec.Runner.String(),
|
||||
Description: spec.Description,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
if len(visible) == 0 {
|
||||
return nil, fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
items := make([]ModelItem, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
description := spec.Runner.String()
|
||||
if conn, err := config.LoadIntegration(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||
func IsIntegrationInstalled(name string) bool {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
if spec.Install.CheckInstalled == nil {
|
||||
return true
|
||||
}
|
||||
return spec.Install.CheckInstalled()
|
||||
}
|
||||
|
||||
// AutoInstallable returns true if the integration can be automatically installed when missing.
|
||||
func AutoInstallable(name string) bool {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return spec.Install.EnsureInstalled != nil
|
||||
}
|
||||
|
||||
// EnsureInstalled checks if an auto-installable integration is present and offers to install it if missing.
|
||||
func EnsureInstalled(name string) error {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if spec.Install.EnsureInstalled == nil || IsIntegrationInstalled(name) {
|
||||
return nil
|
||||
}
|
||||
return spec.Install.EnsureInstalled()
|
||||
}
|
||||
|
||||
// IsEditorIntegration returns true if the named integration uses multi-model selection.
|
||||
func IsEditorIntegration(name string) bool {
|
||||
_, runner, err := LookupIntegration(name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, isEditor := runner.(Editor)
|
||||
return isEditor
|
||||
}
|
||||
|
||||
// IntegrationInstallHint returns a user-friendly install hint for the given integration.
|
||||
func IntegrationInstallHint(name string) string {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if spec.Install.URL != "" {
|
||||
return "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||
}
|
||||
if len(spec.Install.Command) > 0 {
|
||||
return "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||
if IsIntegrationInstalled(name) {
|
||||
return nil
|
||||
}
|
||||
if AutoInstallable(name) {
|
||||
return EnsureInstalled(name)
|
||||
}
|
||||
return IntegrationInstallError(name, runner)
|
||||
}
|
||||
|
||||
// IntegrationInstallError reports a user-facing install error for missing integrations.
|
||||
func IntegrationInstallError(name string, runner Runner) error {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
switch {
|
||||
case spec.Install.URL != "":
|
||||
return fmt.Errorf("%s is not installed, install from %s", spec.Name, spec.Install.URL)
|
||||
case len(spec.Install.Command) > 0:
|
||||
return fmt.Errorf("%s is not installed, install with: %s", spec.Name, strings.Join(spec.Install.Command, " "))
|
||||
default:
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
}
|
||||
68
cmd/launch/runner_exec_only_test.go
Normal file
68
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
binary string
|
||||
runner Runner
|
||||
checkPath func(home string) string
|
||||
}{
|
||||
{
|
||||
name: "droid",
|
||||
binary: "droid",
|
||||
runner: &Droid{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".factory", "settings.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "opencode",
|
||||
binary: "opencode",
|
||||
runner: &OpenCode{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cline",
|
||||
binary: "cline",
|
||||
runner: &Cline{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pi",
|
||||
binary: "pi",
|
||||
runner: &Pi{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
setTestHome(t, home)
|
||||
|
||||
binDir := t.TempDir()
|
||||
writeFakeBinary(t, binDir, tt.binary)
|
||||
t.Setenv("PATH", binDir)
|
||||
|
||||
configPath := tt.checkPath(home)
|
||||
if err := tt.runner.Run("llama3.2", nil); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
78
cmd/launch/selector_hooks.go
Normal file
78
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an internal alias for existing call sites.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
|
||||
// DefaultSingleSelector is the default single-select implementation.
|
||||
var DefaultSingleSelector SingleSelector
|
||||
|
||||
// DefaultMultiSelector is the default multi-select implementation.
|
||||
var DefaultMultiSelector MultiSelector
|
||||
|
||||
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||
// When set, EnsureAuth uses it instead of plain text prompts.
|
||||
// Returns the signed-in username or an error.
|
||||
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||
|
||||
// ConfirmPrompt asks the user to confirm an action using the configured prompt hook.
|
||||
func ConfirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
82
cmd/launch/test_config_helpers_test.go
Normal file
82
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
var (
|
||||
integrations map[string]Runner
|
||||
integrationAliases map[string]bool
|
||||
integrationOrder = launcherIntegrationOrder
|
||||
)
|
||||
|
||||
func init() {
|
||||
integrations = buildTestIntegrations()
|
||||
integrationAliases = buildTestIntegrationAliases()
|
||||
}
|
||||
|
||||
func buildTestIntegrations() map[string]Runner {
|
||||
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||
for name, spec := range integrationSpecsByName {
|
||||
result[strings.ToLower(name)] = spec.Runner
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildTestIntegrationAliases() map[string]bool {
|
||||
result := make(map[string]bool)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
result[strings.ToLower(alias)] = true
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
setLaunchTestHome(t, dir)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
return config.SaveIntegration(appName, models)
|
||||
}
|
||||
|
||||
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
|
||||
return config.LoadIntegration(appName)
|
||||
}
|
||||
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
return config.SaveAliases(appName, aliases)
|
||||
}
|
||||
|
||||
func LastModel() string {
|
||||
return config.LastModel()
|
||||
}
|
||||
|
||||
func SetLastModel(model string) error {
|
||||
return config.SetLastModel(model)
|
||||
}
|
||||
|
||||
func LastSelection() string {
|
||||
return config.LastSelection()
|
||||
}
|
||||
|
||||
func SetLastSelection(selection string) error {
|
||||
return config.SetLastSelection(selection)
|
||||
}
|
||||
|
||||
func IntegrationModel(appName string) string {
|
||||
return config.IntegrationModel(appName)
|
||||
}
|
||||
|
||||
func IntegrationModels(appName string) []string {
|
||||
return config.IntegrationModels(appName)
|
||||
}
|
||||
|
||||
func integrationOnboarded(appName string) error {
|
||||
return config.MarkIntegrationOnboarded(appName)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -64,8 +64,8 @@ type SelectItem struct {
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// ConvertItems converts config.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []config.ModelItem) []SelectItem {
|
||||
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||
out := make([]SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
@@ -101,6 +101,16 @@ type selectorModel struct {
|
||||
width int
|
||||
}
|
||||
|
||||
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m.updateScroll(m.otherStart())
|
||||
return m
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
@@ -382,11 +392,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m := selectorModelWithCurrent(title, items, current)
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
|
||||
@@ -216,6 +216,22 @@ func TestUpdateScroll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
|
||||
|
||||
if m.cursor != 11 {
|
||||
t.Fatalf("cursor = %d, want 11", m.cursor)
|
||||
}
|
||||
if m.scrollOffset == 0 {
|
||||
t.Fatal("scrollOffset should move to reveal current item in More section")
|
||||
}
|
||||
|
||||
content := m.renderContent()
|
||||
if !strings.Contains(content, "▸ other-10") {
|
||||
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
@@ -104,9 +113,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
config.OpenBrowser(signInURL)
|
||||
launch.OpenBrowser(signInURL)
|
||||
|
||||
m := signInModel{
|
||||
modelName: modelName,
|
||||
|
||||
839
cmd/tui/tui.go
839
cmd/tui/tui.go
@@ -1,16 +1,11 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -45,7 +40,7 @@ var (
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string // integration name for loading model config, empty if not an integration
|
||||
integration string
|
||||
isRunModel bool
|
||||
isOthers bool
|
||||
}
|
||||
@@ -57,18 +52,12 @@ var mainMenuItems = []menuItem{
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
title: "Launch Claude Code",
|
||||
description: "Agentic coding across large codebases",
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
title: "Launch Codex",
|
||||
description: "OpenAI's open-source coding agent",
|
||||
integration: "codex",
|
||||
},
|
||||
{
|
||||
title: "Launch OpenClaw",
|
||||
description: "Personal AI with 100+ skills",
|
||||
integration: "openclaw",
|
||||
},
|
||||
}
|
||||
@@ -79,277 +68,106 @@ var othersMenuItem = menuItem{
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
||||
func getOtherIntegrations() []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"run": true, // not an integration but in the pinned list
|
||||
type model struct {
|
||||
state *launch.LauncherState
|
||||
items []menuItem
|
||||
cursor int
|
||||
showOthers bool
|
||||
width int
|
||||
quitting bool
|
||||
selected bool
|
||||
action TUIAction
|
||||
}
|
||||
|
||||
func newModel(state *launch.LauncherState) model {
|
||||
m := model{
|
||||
state: state,
|
||||
}
|
||||
m.showOthers = shouldExpandOthers(state)
|
||||
m.items = buildMenuItems(state, m.showOthers)
|
||||
m.cursor = initialCursor(state, m.items)
|
||||
return m
|
||||
}
|
||||
|
||||
func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||
if state == nil {
|
||||
return false
|
||||
}
|
||||
for _, item := range otherIntegrationItems(state) {
|
||||
if item.integration == state.LastSelection {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
||||
for _, item := range mainMenuItems {
|
||||
if item.integration != "" {
|
||||
pinned[item.integration] = true
|
||||
if item.integration == "" {
|
||||
items = append(items, item)
|
||||
continue
|
||||
}
|
||||
if integrationState, ok := state.Integrations[item.integration]; ok {
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
}
|
||||
|
||||
var others []menuItem
|
||||
for _, info := range config.ListIntegrationInfos() {
|
||||
if showOthers {
|
||||
items = append(items, otherIntegrationItems(state)...)
|
||||
} else {
|
||||
items = append(items, othersMenuItem)
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||
description := state.Description
|
||||
if description == "" {
|
||||
description = "Open " + state.DisplayName + " integration"
|
||||
}
|
||||
return menuItem{
|
||||
title: "Launch " + state.DisplayName,
|
||||
description: description,
|
||||
integration: state.Name,
|
||||
}
|
||||
}
|
||||
|
||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"claude": true,
|
||||
"codex": true,
|
||||
"openclaw": true,
|
||||
}
|
||||
|
||||
var items []menuItem
|
||||
for _, info := range launch.ListIntegrationInfos() {
|
||||
if pinned[info.Name] {
|
||||
continue
|
||||
}
|
||||
desc := info.Description
|
||||
if desc == "" {
|
||||
desc = "Open " + info.DisplayName + " integration"
|
||||
}
|
||||
others = append(others, menuItem{
|
||||
title: "Launch " + info.DisplayName,
|
||||
description: desc,
|
||||
integration: info.Name,
|
||||
})
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
type model struct {
|
||||
items []menuItem
|
||||
cursor int
|
||||
quitting bool
|
||||
selected bool
|
||||
changeModel bool
|
||||
changeModels []string // multi-select result for Editor integrations
|
||||
showOthers bool
|
||||
availableModels map[string]bool
|
||||
err error
|
||||
|
||||
showingModal bool
|
||||
modalSelector selectorModel
|
||||
modalItems []SelectItem
|
||||
|
||||
showingMultiModal bool
|
||||
multiModalSelector multiSelectorModel
|
||||
|
||||
showingSignIn bool
|
||||
signInURL string
|
||||
signInModel string
|
||||
signInSpinner int
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
|
||||
width int // terminal width from WindowSizeMsg
|
||||
statusMsg string // temporary status message shown near help text
|
||||
}
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
return true
|
||||
}
|
||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
||||
for modelName := range m.availableModels {
|
||||
if strings.HasPrefix(modelName, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *model) buildModalItems() []SelectItem {
|
||||
modelItems, _ := config.GetModelItems(context.Background())
|
||||
return ReorderItems(ConvertItems(modelItems))
|
||||
}
|
||||
|
||||
func (m *model) openModelModal(currentModel string) {
|
||||
m.modalItems = m.buildModalItems()
|
||||
cursor := 0
|
||||
if currentModel != "" {
|
||||
for i, item := range m.modalItems {
|
||||
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
|
||||
cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.modalSelector = selectorModel{
|
||||
title: "Select model:",
|
||||
items: m.modalItems,
|
||||
cursor: cursor,
|
||||
helpText: "↑/↓ navigate • enter select • ← back",
|
||||
}
|
||||
m.modalSelector.updateScroll(m.modalSelector.otherStart())
|
||||
m.showingModal = true
|
||||
}
|
||||
|
||||
func (m *model) openMultiModelModal(integration string) {
|
||||
items := m.buildModalItems()
|
||||
var preChecked []string
|
||||
if models := config.IntegrationModels(integration); len(models) > 0 {
|
||||
preChecked = models
|
||||
}
|
||||
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
|
||||
// Set cursor to the first pre-checked (last used) model
|
||||
if len(preChecked) > 0 {
|
||||
for i, item := range items {
|
||||
if item.Name == preChecked[0] {
|
||||
m.multiModalSelector.cursor = i
|
||||
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.showingMultiModal = true
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(client *api.Client) bool {
|
||||
status, err := client.CloudStatusExperimental(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return status.Cloud.Disabled
|
||||
}
|
||||
|
||||
func cloudModelDisabled(name string) bool {
|
||||
if !isCloudModel(name) {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cloudStatusDisabled(client)
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if modelName == "" || !isCloudModel(modelName) {
|
||||
return nil
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if cloudStatusDisabled(client) {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSignIn initiates the sign-in flow for a cloud model.
|
||||
// fromModal indicates if this was triggered from the model picker modal.
|
||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
||||
m.showingModal = false
|
||||
m.showingSignIn = true
|
||||
m.signInURL = signInURL
|
||||
m.signInModel = modelName
|
||||
m.signInSpinner = 0
|
||||
m.signInFromModal = fromModal
|
||||
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func (m *model) loadAvailableModels() {
|
||||
m.availableModels = make(map[string]bool)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cloudDisabled := cloudStatusDisabled(client)
|
||||
for _, mdl := range models.Models {
|
||||
if cloudDisabled && mdl.RemoteModel != "" {
|
||||
integrationState, ok := state.Integrations[info.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m.availableModels[mdl.Name] = true
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (m *model) buildItems() {
|
||||
others := getOtherIntegrations()
|
||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
||||
m.items = append(m.items, mainMenuItems...)
|
||||
|
||||
if m.showOthers {
|
||||
m.items = append(m.items, others...)
|
||||
} else {
|
||||
m.items = append(m.items, othersMenuItem)
|
||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||
if state == nil || state.LastSelection == "" {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func isOthersIntegration(name string) bool {
|
||||
for _, item := range getOtherIntegrations() {
|
||||
if item.integration == name {
|
||||
return true
|
||||
for i, item := range items {
|
||||
if state.LastSelection == "run" && item.isRunModel {
|
||||
return i
|
||||
}
|
||||
if item.integration == state.LastSelection {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initialModel() model {
|
||||
m := model{
|
||||
cursor: 0,
|
||||
}
|
||||
m.loadAvailableModels()
|
||||
|
||||
lastSelection := config.LastSelection()
|
||||
if isOthersIntegration(lastSelection) {
|
||||
m.showOthers = true
|
||||
}
|
||||
|
||||
m.buildItems()
|
||||
|
||||
if lastSelection != "" {
|
||||
for i, item := range m.items {
|
||||
if lastSelection == "run" && item.isRunModel {
|
||||
m.cursor = i
|
||||
break
|
||||
} else if item.integration == lastSelection {
|
||||
m.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
@@ -357,143 +175,11 @@ func (m model) Init() tea.Cmd {
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
||||
wasSet := m.width > 0
|
||||
m.width = wmsg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if _, ok := msg.(clearStatusMsg); ok {
|
||||
m.statusMsg = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.showingSignIn = false
|
||||
if m.signInFromModal {
|
||||
m.showingModal = true
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.signInSpinner++
|
||||
// Check sign-in status every 5th tick (~1 second)
|
||||
if m.signInSpinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
if m.signInFromModal {
|
||||
m.modalSelector.selected = m.signInModel
|
||||
m.changeModel = true
|
||||
} else {
|
||||
m.selected = true
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingMultiModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyLeft {
|
||||
m.showingMultiModal = false
|
||||
return m, nil
|
||||
}
|
||||
updated, cmd := m.multiModalSelector.Update(msg)
|
||||
m.multiModalSelector = updated.(multiSelectorModel)
|
||||
|
||||
if m.multiModalSelector.cancelled {
|
||||
m.showingMultiModal = false
|
||||
return m, nil
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
m.multiModalSelector.confirmed = false
|
||||
return m, nil
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
||||
m.showingModal = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
||||
}
|
||||
if m.modalSelector.selected != "" {
|
||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
default:
|
||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
||||
m.modalSelector.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
@@ -504,162 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||
m.showOthers = false
|
||||
m.buildItems()
|
||||
m.items = buildMenuItems(m.state, false)
|
||||
m.cursor = min(m.cursor, len(m.items)-1)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
// Auto-expand "Others..." when cursor lands on it
|
||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||
m.showOthers = true
|
||||
m.buildItems()
|
||||
// cursor now points at the first "other" integration
|
||||
m.items = buildMenuItems(m.state, true)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
|
||||
return m, nil
|
||||
if m.selectableItem(m.items[m.cursor]) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(m.items[m.cursor], false)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
var configuredModel string
|
||||
if item.isRunModel {
|
||||
configuredModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
configuredModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
m.openModelModal(configuredModel)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
return m, nil
|
||||
|
||||
case "right", "l":
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
// Auto-installable: select to trigger install flow
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
var currentModel string
|
||||
if item.isRunModel {
|
||||
currentModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
currentModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
m.openModelModal(currentModel)
|
||||
}
|
||||
if item.isRunModel || m.changeableItem(item) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(item, true)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) selectableItem(item menuItem) bool {
|
||||
if item.isRunModel {
|
||||
return true
|
||||
}
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Selectable
|
||||
}
|
||||
|
||||
func (m model) changeableItem(item menuItem) bool {
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Changeable
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
return m.renderSignInDialog()
|
||||
}
|
||||
|
||||
if m.showingMultiModal {
|
||||
return m.multiModalSelector.View()
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
return m.renderModal()
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
isInstalled := true
|
||||
|
||||
if item.integration != "" {
|
||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
||||
}
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = "▸ "
|
||||
if isInstalled {
|
||||
style = menuSelectedItemStyle
|
||||
} else {
|
||||
style = greyedSelectedStyle
|
||||
}
|
||||
} else if !isInstalled && item.integration != "" {
|
||||
style = greyedStyle
|
||||
}
|
||||
|
||||
title := item.title
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
} else if item.isRunModel && m.cursor == i {
|
||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
|
||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
desc = "Press enter to install"
|
||||
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
}
|
||||
}
|
||||
s += menuDescStyle.Render(desc) + "\n\n"
|
||||
s += m.renderMenuItem(i, item)
|
||||
}
|
||||
|
||||
if m.statusMsg != "" {
|
||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
||||
}
|
||||
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
@@ -667,80 +269,125 @@ func (m model) View() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderModal() string {
|
||||
modalStyle := lipgloss.NewStyle().
|
||||
PaddingBottom(1).
|
||||
PaddingRight(2)
|
||||
func (m model) renderMenuItem(index int, item menuItem) string {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
title := item.title
|
||||
description := item.description
|
||||
modelSuffix := ""
|
||||
|
||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderSignInDialog() string {
|
||||
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
|
||||
}
|
||||
|
||||
type Selection int
|
||||
|
||||
const (
|
||||
SelectionNone Selection = iota
|
||||
SelectionRunModel
|
||||
SelectionChangeRunModel
|
||||
SelectionIntegration // Generic integration selection
|
||||
SelectionChangeIntegration // Generic change model for integration
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Selection Selection
|
||||
Integration string // integration name if applicable
|
||||
Model string // model name if selected from single-select modal
|
||||
Models []string // models selected from multi-select modal (Editor integrations)
|
||||
}
|
||||
|
||||
func Run() (Result, error) {
|
||||
m := initialModel()
|
||||
p := tea.NewProgram(m)
|
||||
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(model)
|
||||
if fm.err != nil {
|
||||
return Result{Selection: SelectionNone}, fm.err
|
||||
}
|
||||
|
||||
if !fm.selected && !fm.changeModel {
|
||||
return Result{Selection: SelectionNone}, nil
|
||||
}
|
||||
|
||||
item := fm.items[fm.cursor]
|
||||
|
||||
if fm.changeModel {
|
||||
if item.isRunModel {
|
||||
return Result{
|
||||
Selection: SelectionChangeRunModel,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
return Result{
|
||||
Selection: SelectionChangeIntegration,
|
||||
Integration: item.integration,
|
||||
Model: fm.modalSelector.selected,
|
||||
Models: fm.changeModels,
|
||||
}, nil
|
||||
if m.cursor == index {
|
||||
cursor = "▸ "
|
||||
}
|
||||
|
||||
if item.isRunModel {
|
||||
return Result{Selection: SelectionRunModel}, nil
|
||||
if m.cursor == index && m.state.RunModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
|
||||
}
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else if item.isOthers {
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else {
|
||||
integrationState := m.state.Integrations[item.integration]
|
||||
if !integrationState.Selectable {
|
||||
if m.cursor == index {
|
||||
style = greyedSelectedStyle
|
||||
} else {
|
||||
style = greyedStyle
|
||||
}
|
||||
} else if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
|
||||
if m.cursor == index && integrationState.CurrentModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
|
||||
}
|
||||
|
||||
if !integrationState.Installed {
|
||||
if integrationState.AutoInstallable {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
if m.cursor == index {
|
||||
if integrationState.AutoInstallable {
|
||||
description = "Press enter to install"
|
||||
} else if integrationState.InstallHint != "" {
|
||||
description = integrationState.InstallHint
|
||||
} else {
|
||||
description = "not installed"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Result{
|
||||
Selection: SelectionIntegration,
|
||||
Integration: item.integration,
|
||||
}, nil
|
||||
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
|
||||
}
|
||||
|
||||
type TUIActionKind int
|
||||
|
||||
const (
|
||||
TUIActionNone TUIActionKind = iota
|
||||
TUIActionRunModel
|
||||
TUIActionLaunchIntegration
|
||||
)
|
||||
|
||||
type TUIAction struct {
|
||||
Kind TUIActionKind
|
||||
Integration string
|
||||
ForceConfigure bool
|
||||
}
|
||||
|
||||
func (a TUIAction) LastSelection() string {
|
||||
switch a.Kind {
|
||||
case TUIActionRunModel:
|
||||
return "run"
|
||||
case TUIActionLaunchIntegration:
|
||||
return a.Integration
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
|
||||
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
|
||||
}
|
||||
|
||||
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
|
||||
return launch.IntegrationLaunchRequest{
|
||||
Name: a.Integration,
|
||||
ForceConfigure: a.ForceConfigure,
|
||||
}
|
||||
}
|
||||
|
||||
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
|
||||
switch {
|
||||
case item.isRunModel:
|
||||
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
|
||||
case item.integration != "":
|
||||
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
|
||||
default:
|
||||
return TUIAction{Kind: TUIActionNone}
|
||||
}
|
||||
}
|
||||
|
||||
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
|
||||
menu := newModel(state)
|
||||
program := tea.NewProgram(menu)
|
||||
|
||||
finalModel, err := program.Run()
|
||||
if err != nil {
|
||||
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
finalMenu := finalModel.(model)
|
||||
if !finalMenu.selected {
|
||||
return TUIAction{Kind: TUIActionNone}, nil
|
||||
}
|
||||
|
||||
return finalMenu.action, nil
|
||||
}
|
||||
|
||||
178
cmd/tui/tui_test.go
Normal file
178
cmd/tui/tui_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
func launcherTestState() *launch.LauncherState {
|
||||
return &launch.LauncherState{
|
||||
LastSelection: "run",
|
||||
RunModel: "qwen3:8b",
|
||||
Integrations: map[string]launch.LauncherIntegrationState{
|
||||
"claude": {
|
||||
Name: "claude",
|
||||
DisplayName: "Claude Code",
|
||||
Description: "Anthropic's coding tool with subagents",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
CurrentModel: "glm-5:cloud",
|
||||
},
|
||||
"codex": {
|
||||
Name: "codex",
|
||||
DisplayName: "Codex",
|
||||
Description: "OpenAI's open-source coding agent",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"openclaw": {
|
||||
Name: "openclaw",
|
||||
DisplayName: "OpenClaw",
|
||||
Description: "Personal AI with 100+ skills",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
AutoInstallable: true,
|
||||
},
|
||||
"droid": {
|
||||
Name: "droid",
|
||||
DisplayName: "Droid",
|
||||
Description: "Factory's coding agent across terminal and IDEs",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"pi": {
|
||||
Name: "pi",
|
||||
DisplayName: "Pi",
|
||||
Description: "Minimal AI agent toolkit with plugin support",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||
view := newModel(launcherTestState()).View()
|
||||
for _, want := range []string{"Run a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
state.LastSelection = "pi"
|
||||
|
||||
menu := newModel(state)
|
||||
if !menu.showOthers {
|
||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||
}
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "Launch Pi") {
|
||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "More...") {
|
||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = 1
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = 1
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuIgnoresDisabledActions(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
claude := state.Integrations["claude"]
|
||||
claude.Selectable = false
|
||||
claude.Changeable = false
|
||||
state.Integrations["claude"] = claude
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = 1
|
||||
|
||||
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if updatedEnter.(model).selected {
|
||||
t.Fatal("expected non-selectable integration to ignore enter")
|
||||
}
|
||||
|
||||
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
if updatedRight.(model).selected {
|
||||
t.Fatal("expected non-changeable integration to ignore right")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
runView := menu.View()
|
||||
if !strings.Contains(runView, "(qwen3:8b)") {
|
||||
t.Fatalf("expected run row to show current model suffix\n%s", runView)
|
||||
}
|
||||
|
||||
menu.cursor = 1
|
||||
integrationView := menu.View()
|
||||
if !strings.Contains(integrationView, "(glm-5:cloud)") {
|
||||
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
codex := state.Integrations["codex"]
|
||||
codex.Installed = false
|
||||
codex.Selectable = false
|
||||
codex.Changeable = false
|
||||
codex.InstallHint = "Install from https://example.com/codex"
|
||||
state.Integrations["codex"] = codex
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = 2
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "(not installed)") {
|
||||
t.Fatalf("expected not-installed marker\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, codex.InstallHint) {
|
||||
t.Fatalf("expected install hint in description\n%s", view)
|
||||
}
|
||||
}
|
||||
@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
|
||||
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
|
||||
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||
|
||||
115
internal/modelref/modelref.go
Normal file
115
internal/modelref/modelref.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelSource uint8
|
||||
|
||||
const (
|
||||
ModelSourceUnspecified ModelSource = iota
|
||||
ModelSourceLocal
|
||||
ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
|
||||
ErrModelRequired = errors.New("model is required")
|
||||
)
|
||||
|
||||
type ParsedRef struct {
|
||||
Original string
|
||||
Base string
|
||||
Source ModelSource
|
||||
}
|
||||
|
||||
func ParseRef(raw string) (ParsedRef, error) {
|
||||
var zero ParsedRef
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return zero, ErrModelRequired
|
||||
}
|
||||
|
||||
base, source, explicit := parseSourceSuffix(raw)
|
||||
if explicit {
|
||||
if _, _, nested := parseSourceSuffix(base); nested {
|
||||
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
|
||||
}
|
||||
}
|
||||
|
||||
return ParsedRef{
|
||||
Original: raw,
|
||||
Base: base,
|
||||
Source: source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func HasExplicitCloudSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceCloud
|
||||
}
|
||||
|
||||
func HasExplicitLocalSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceLocal
|
||||
}
|
||||
|
||||
func StripCloudSourceTag(raw string) (string, bool) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil || parsedRef.Source != ModelSourceCloud {
|
||||
return strings.TrimSpace(raw), false
|
||||
}
|
||||
|
||||
return parsedRef.Base, true
|
||||
}
|
||||
|
||||
func NormalizePullName(raw string) (string, bool, error) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if parsedRef.Source != ModelSourceCloud {
|
||||
return parsedRef.Base, false, nil
|
||||
}
|
||||
|
||||
return toLegacyCloudPullName(parsedRef.Base), true, nil
|
||||
}
|
||||
|
||||
func toLegacyCloudPullName(base string) string {
|
||||
if hasExplicitTag(base) {
|
||||
return base + "-cloud"
|
||||
}
|
||||
|
||||
return base + ":cloud"
|
||||
}
|
||||
|
||||
func hasExplicitTag(name string) bool {
|
||||
lastSlash := strings.LastIndex(name, "/")
|
||||
lastColon := strings.LastIndex(name, ":")
|
||||
return lastColon > lastSlash
|
||||
}
|
||||
|
||||
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
|
||||
idx := strings.LastIndex(raw, ":")
|
||||
if idx >= 0 {
|
||||
suffixRaw := strings.TrimSpace(raw[idx+1:])
|
||||
suffix := strings.ToLower(suffixRaw)
|
||||
|
||||
switch suffix {
|
||||
case "cloud":
|
||||
return raw[:idx], ModelSourceCloud, true
|
||||
case "local":
|
||||
return raw[:idx], ModelSourceLocal, true
|
||||
}
|
||||
|
||||
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
|
||||
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
|
||||
}
|
||||
}
|
||||
|
||||
return raw, ModelSourceUnspecified, false
|
||||
}
|
||||
268
internal/modelref/modelref_test.go
Normal file
268
internal/modelref/modelref_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRef(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantErr error
|
||||
wantCloud bool
|
||||
wantLocal bool
|
||||
wantStripped string
|
||||
wantStripOK bool
|
||||
}{
|
||||
{
|
||||
name: "cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantLocal: true,
|
||||
wantStripped: "qwen3:8b:local",
|
||||
},
|
||||
{
|
||||
name: "no source suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "llama3.2",
|
||||
},
|
||||
{
|
||||
name: "bare cloud name is not explicit cloud",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "my-cloud-model",
|
||||
},
|
||||
{
|
||||
name: "slash in suffix blocks legacy cloud parsing",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "foo:bar-cloud/baz",
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: " ",
|
||||
wantErr: ErrModelRequired,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseRef(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if got.Base != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
|
||||
}
|
||||
|
||||
if got.Source != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
|
||||
}
|
||||
|
||||
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
|
||||
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
|
||||
}
|
||||
|
||||
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
|
||||
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
|
||||
}
|
||||
|
||||
stripped, ok := StripCloudSourceTag(tt.input)
|
||||
if ok != tt.wantStripOK {
|
||||
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
|
||||
}
|
||||
if stripped != tt.wantStripped {
|
||||
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePullName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantName string
|
||||
wantCloud bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "explicit local strips source",
|
||||
input: "gpt-oss:20b:local",
|
||||
wantName: "gpt-oss:20b",
|
||||
},
|
||||
{
|
||||
name: "explicit cloud with size maps to legacy dash cloud tag",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud with size remains stable",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "explicit cloud without tag maps to cloud tag",
|
||||
input: "qwen3:cloud",
|
||||
wantName: "qwen3:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "host port without tag keeps host port and appends cloud tag",
|
||||
input: "localhost:11434/library/foo:cloud",
|
||||
wantName: "localhost:11434/library/foo:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes fail",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotName, gotCloud, err := NormalizePullName(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotCloud != tt.wantCloud {
|
||||
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSourceSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantExplicit bool
|
||||
}{
|
||||
{
|
||||
name: "explicit cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix on tag",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix does not match model segment",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix blocked when suffix includes slash",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "unknown suffix is not explicit source",
|
||||
input: "gpt-oss:clod",
|
||||
wantBase: "gpt-oss:clod",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "uppercase suffix is accepted",
|
||||
input: "gpt-oss:20b:CLOUD",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "no suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
|
||||
if gotBase != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
|
||||
}
|
||||
if gotSource != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
|
||||
}
|
||||
if gotExplicit != tt.wantExplicit {
|
||||
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
||||
TotalSize() uint64
|
||||
MemorySize() (total, vram uint64)
|
||||
VRAMByGPU(id ml.DeviceID) uint64
|
||||
Pid() int
|
||||
GetPort() int
|
||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
// Windows CUDA should not use mmap for best performance
|
||||
// Linux with a model larger than free space, mmap leads to thrashing
|
||||
// For CPU loads we want the memory to be allocated, not FS cache
|
||||
totalSize, _ := s.MemorySize()
|
||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMSize() uint64 {
|
||||
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
var mem uint64
|
||||
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
vram += g.Size()
|
||||
}
|
||||
|
||||
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||
|
||||
// Some elements are always on CPU. However, if we have allocated all layers
|
||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||
noCPULayers := true
|
||||
@@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
||||
}
|
||||
}
|
||||
if noCPULayers {
|
||||
mem += s.mem.InputWeights
|
||||
mem += s.mem.CPU.Graph
|
||||
vram += s.mem.InputWeights
|
||||
vram += s.mem.CPU.Graph
|
||||
}
|
||||
|
||||
return mem
|
||||
}
|
||||
|
||||
func (s *llmServer) TotalSize() uint64 {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
mem := s.mem.InputWeights
|
||||
mem += s.mem.CPU.Size()
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
}
|
||||
|
||||
return mem
|
||||
return total, vram
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
@@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
|
||||
@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||
}
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
@@ -442,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
// Collect chunk outputs and concatenate at the end.
|
||||
// Avoids SET on buffer-less intermediates under partial offload.
|
||||
chunks := make([]ml.Tensor, nChunks)
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
@@ -463,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
v = v.SetInplace(
|
||||
ctx,
|
||||
coreAttnOutChunk,
|
||||
v.Stride(1),
|
||||
v.Stride(2),
|
||||
v.Stride(3),
|
||||
chunk*v.Stride(2),
|
||||
)
|
||||
chunks[chunk] = coreAttnOutChunk
|
||||
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
@@ -483,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||
}
|
||||
|
||||
// Use a balanced concat tree so concat work does not balloon on long prompts.
|
||||
for len(chunks) > 1 {
|
||||
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
|
||||
for i := 0; i < len(chunks); i += 2 {
|
||||
if i+1 < len(chunks) {
|
||||
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
|
||||
} else {
|
||||
merged = append(merged, chunks[i])
|
||||
}
|
||||
}
|
||||
chunks = merged
|
||||
}
|
||||
v = chunks[0]
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
|
||||
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.Options == nil {
|
||||
return fmt.Errorf("qwen3next: missing model options")
|
||||
}
|
||||
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if !m.Options.isRecurrent[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||
if !ok || gdn == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||
}
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||
}
|
||||
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
if len(m.mropeSections) > 0 {
|
||||
@@ -450,6 +490,64 @@ var (
|
||||
_ model.MultimodalProcessor = (*Model)(nil)
|
||||
)
|
||||
|
||||
func defaultVHeadReordered(arch string) bool {
|
||||
return arch == "qwen35" || arch == "qwen35moe"
|
||||
}
|
||||
|
||||
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||
isRecurrent := make([]bool, numLayers)
|
||||
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
if i >= len(headCountKV) {
|
||||
continue
|
||||
}
|
||||
|
||||
if headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if hasZero && hasFull {
|
||||
return isRecurrent, nil
|
||||
}
|
||||
if !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Compatibility path: older imports store a scalar KV head count and omit
|
||||
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||
interval := int(fullAttentionInterval)
|
||||
if interval == 0 {
|
||||
interval = min(4, numLayers)
|
||||
}
|
||||
if interval <= 0 {
|
||||
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||
}
|
||||
if interval > numLayers {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||
}
|
||||
|
||||
hasZero = false
|
||||
hasFull = false
|
||||
for i := range numLayers {
|
||||
isRecurrent[i] = (i+1)%interval != 0
|
||||
if isRecurrent[i] {
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||
}
|
||||
|
||||
return isRecurrent, nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||
isRecurrent: isRecurrent,
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range mropeSections {
|
||||
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
|
||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, false, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, true, false, true, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, false, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||
if err == nil {
|
||||
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultVHeadReordered(t *testing.T) {
|
||||
if !defaultVHeadReordered("qwen35") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||
}
|
||||
if !defaultVHeadReordered("qwen35moe") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||
}
|
||||
if defaultVHeadReordered("qwen3next") {
|
||||
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||
}
|
||||
}
|
||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{
|
||||
Operator: &GatedDeltaNet{
|
||||
SSMQKV: &nn.Linear{},
|
||||
SSMQKVGate: &nn.Linear{},
|
||||
SSMBeta: &nn.Linear{},
|
||||
SSMAlpha: &nn.Linear{},
|
||||
},
|
||||
}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{true},
|
||||
},
|
||||
}
|
||||
|
||||
err := m.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Validate() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{false},
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
}
|
||||
@@ -32,9 +32,10 @@ const (
|
||||
)
|
||||
|
||||
type GLM46Parser struct {
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools
|
||||
}
|
||||
|
||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
|
||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
||||
|
||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
|
||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `plan</think>
|
||||
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
p.callIndex = 0
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
@@ -204,6 +208,24 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
@@ -215,7 +237,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
@@ -146,6 +146,68 @@ func TestQwen3ParserToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "Let me think" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if content != "" || thinking != "Let me think" || len(calls) != 0 {
|
||||
t.Fatalf(
|
||||
"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
|
||||
"",
|
||||
"Let me think",
|
||||
0,
|
||||
content,
|
||||
thinking,
|
||||
len(calls),
|
||||
)
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
@@ -168,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,9 +29,10 @@ const (
|
||||
)
|
||||
|
||||
type Qwen3CoderParser struct {
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||
|
||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools // Qwen doesn't modify tools
|
||||
}
|
||||
|
||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
|
||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenXMLTransform(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
|
||||
@@ -180,7 +180,22 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
return events, false
|
||||
}
|
||||
case CollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
acc := p.buffer.String()
|
||||
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, toolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: before})
|
||||
}
|
||||
p.state = CollectingToolContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, thinkingCloseTag) {
|
||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||
@@ -191,13 +206,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
p.state = CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
@@ -205,11 +220,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
|
||||
@@ -98,8 +98,12 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
desc: "nested thinking and tool call (outside thinking, inside tool call)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -109,8 +113,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -121,8 +124,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
|
||||
@@ -8,7 +8,21 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type GlmOcrRenderer struct{}
|
||||
type GlmOcrRenderer struct {
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||
var sb strings.Builder
|
||||
for range message.Images {
|
||||
if r.useImgTags {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
}
|
||||
}
|
||||
sb.WriteString(message.Content)
|
||||
return sb.String(), imageOffset
|
||||
}
|
||||
|
||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
@@ -38,11 +52,14 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||
thinkingExplicitlySet = true
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
content, nextOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextOffset
|
||||
sb.WriteString(content)
|
||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
|
||||
99
model/renderers/glmocr_test.go
Normal file
99
model/renderers/glmocr_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer *GlmOcrRenderer
|
||||
messages []api.Message
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "use_img_tags_single_image",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "use_img_tags_multiple_images",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe these images.",
|
||||
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "multi_turn_increments_image_offset",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "First image",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Processed.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Second image",
|
||||
Images: []api.ImageData{api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "default_no_img_tags",
|
||||
renderer: &GlmOcrRenderer{},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "No image tags expected.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\nNo image tags expected.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "no_images_content_unchanged",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Text only message.",
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\nText only message.<|assistant|>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.renderer.Render(tt.messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrRenderer{}
|
||||
return &GlmOcrRenderer{useImgTags: RenderImgTags}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||
case "lfm2-thinking":
|
||||
|
||||
@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
seq.sampler.Reset()
|
||||
// Skip this sequence but continue processing the rest
|
||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||
err = nil
|
||||
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// (unless we take down the whole runner).
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
for _, inp := range seq.pendingInputs {
|
||||
if len(inp.Multimodal) != 0 {
|
||||
continue
|
||||
}
|
||||
seq.sampler.Accept(inp.Token)
|
||||
}
|
||||
seq.pendingInputs = []*input.Input{}
|
||||
}
|
||||
|
||||
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
req.Options.TopK,
|
||||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.RepeatPenalty,
|
||||
req.Options.PresencePenalty,
|
||||
req.Options.FrequencyPenalty,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
)
|
||||
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
seq.sampler.Reset()
|
||||
for _, inp := range seq.cache.Inputs {
|
||||
if len(inp.Multimodal) != 0 {
|
||||
continue
|
||||
}
|
||||
seq.sampler.Accept(inp.Token)
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
|
||||
@@ -16,24 +16,49 @@ type token struct {
|
||||
value float32 // The raw logit or probability from the model
|
||||
}
|
||||
|
||||
const DefaultPenaltyLookback = 64
|
||||
|
||||
type Sampler struct {
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
repeat float32
|
||||
presence float32
|
||||
frequency float32
|
||||
history []int32
|
||||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Reset() {
|
||||
s.history = s.history[:0]
|
||||
}
|
||||
|
||||
func (s *Sampler) Accept(token int32) {
|
||||
s.history = append(s.history, token)
|
||||
if len(s.history) > DefaultPenaltyLookback {
|
||||
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
|
||||
s.history = s.history[:DefaultPenaltyLookback]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
counts := tokenCounts(s.history, len(logits))
|
||||
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
value := logits[i]
|
||||
if count := counts[int32(i)]; count > 0 {
|
||||
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||
}
|
||||
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = logits[i]
|
||||
tokens[i].value = value
|
||||
}
|
||||
|
||||
t, err := s.sample(tokens)
|
||||
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
// we need to reset them before applying the grammar and
|
||||
// sampling again
|
||||
for i := range logits {
|
||||
value := logits[i]
|
||||
if count := counts[int32(i)]; count > 0 {
|
||||
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||
}
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = logits[i]
|
||||
tokens[i].value = value
|
||||
}
|
||||
s.grammar.Apply(tokens)
|
||||
t, err = s.sample(tokens)
|
||||
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||
minP = 1.0
|
||||
}
|
||||
|
||||
if repeatPenalty <= 0 {
|
||||
repeatPenalty = 1.0
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
repeat: repeatPenalty,
|
||||
presence: presencePenalty,
|
||||
frequency: frequencyPenalty,
|
||||
grammar: grammar,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
sampler.Sample(logits)
|
||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
for _, tc := range configs {
|
||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
|
||||
sampler.Sample(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
// Test with combined transforms separately - topK influences performance greatly
|
||||
b.Run("TransformCombined", func(b *testing.B) {
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
||||
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
logits := []float32{-10, 3, -10, -10}
|
||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||
got, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{-100, -10, 0, 10}
|
||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
|
||||
// Test very high p
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
|
||||
}
|
||||
|
||||
// Generate random logits for benchmarking
|
||||
|
||||
@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
|
||||
return x
|
||||
}
|
||||
|
||||
func tokenCounts(history []int32, vocabSize int) map[int32]int {
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := 0
|
||||
if len(history) > DefaultPenaltyLookback {
|
||||
start = len(history) - DefaultPenaltyLookback
|
||||
}
|
||||
|
||||
counts := make(map[int32]int, len(history)-start)
|
||||
for _, token := range history[start:] {
|
||||
if token < 0 || int(token) >= vocabSize {
|
||||
continue
|
||||
}
|
||||
counts[token]++
|
||||
}
|
||||
|
||||
return counts
|
||||
}
|
||||
|
||||
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
|
||||
if repeatPenalty != 1.0 {
|
||||
// Preserve ordering for negative logits when applying repeat penalty.
|
||||
if logit < 0 {
|
||||
logit *= repeatPenalty
|
||||
} else {
|
||||
logit /= repeatPenalty
|
||||
}
|
||||
}
|
||||
|
||||
if frequencyPenalty != 0 {
|
||||
logit -= float32(count) * frequencyPenalty
|
||||
}
|
||||
|
||||
if presencePenalty != 0 {
|
||||
logit -= presencePenalty
|
||||
}
|
||||
|
||||
return logit
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
|
||||
@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCounts(t *testing.T) {
|
||||
history := make([]int32, 70)
|
||||
history[0] = 7
|
||||
history[69] = 7
|
||||
|
||||
counts := tokenCounts(history, 8)
|
||||
if got := counts[7]; got != 1 {
|
||||
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPenalty(t *testing.T) {
|
||||
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
|
||||
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
|
||||
}
|
||||
|
||||
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
|
||||
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
|
||||
}
|
||||
|
||||
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
|
||||
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
|
||||
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplerPresencePenalty(t *testing.T) {
|
||||
logits := []float32{0.0, 5.0, 0.0}
|
||||
|
||||
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||
baseline.Accept(1)
|
||||
got, err := baseline.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||
}
|
||||
|
||||
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
|
||||
presence.Accept(1)
|
||||
got, err = presence.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got == 1 {
|
||||
t.Fatalf("presence penalty did not change repeated token selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplerFrequencyPenalty(t *testing.T) {
|
||||
logits := []float32{0.0, 5.0, 4.0}
|
||||
|
||||
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||
baseline.Accept(1)
|
||||
baseline.Accept(1)
|
||||
baseline.Accept(1)
|
||||
got, err := baseline.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||
}
|
||||
|
||||
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
|
||||
frequency.Accept(1)
|
||||
frequency.Accept(1)
|
||||
frequency.Accept(1)
|
||||
got, err = frequency.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 2 {
|
||||
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
// Generate random logits
|
||||
tokens := make([]token, 1<<16)
|
||||
|
||||
460
server/cloud_proxy.go
Normal file
460
server/cloud_proxy.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
||||
defaultCloudProxySigningHost = "ollama.com"
|
||||
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
||||
cloudProxySigningHost = defaultCloudProxySigningHost
|
||||
cloudProxySignRequest = signCloudProxyRequest
|
||||
cloudProxySigninURL = signinURL
|
||||
)
|
||||
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"connection": {},
|
||||
"content-length": {},
|
||||
"proxy-connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
|
||||
func init() {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
cloudProxyBaseURL = baseURL
|
||||
cloudProxySigningHost = signingHost
|
||||
|
||||
if overridden {
|
||||
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method != http.MethodPost {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||
// A future optimization can parse just enough JSON to read "model" (and
|
||||
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||
// preserving raw passthrough semantics.
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(model)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
||||
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
||||
if c.Request.URL.Path == "/v1/messages" {
|
||||
if hasAnthropicWebSearchTool(body) {
|
||||
c.Set(legacyCloudAnthropicKey, true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(modelName)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
proxyPath := "/v1/models/" + modelRef.Base
|
||||
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
||||
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
||||
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
||||
// stop doing this, we can inline this method.
|
||||
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
||||
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := baseURL.ResolveReference(&url.URL{
|
||||
Path: path,
|
||||
RawQuery: c.Request.URL.RawQuery,
|
||||
})
|
||||
|
||||
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
||||
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
||||
outReq.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
||||
slog.Warn("cloud proxy signing failed", "error", err)
|
||||
writeCloudUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Add phase-specific proxy timeouts.
|
||||
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
||||
// we should not enforce a short total timeout for long-lived responses.
|
||||
resp, err := http.DefaultClient.Do(outReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
||||
c.Error(err) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelJSON, err := json.Marshal(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload["model"] = modelJSON
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func readRequestBody(r *http.Request) ([]byte, error) {
|
||||
if r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func extractModelField(body []byte) (string, bool) {
|
||||
if len(body) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw, ok := payload["model"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var model string
|
||||
if err := json.Unmarshal(raw, &model); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
return model, model != ""
|
||||
}
|
||||
|
||||
func hasAnthropicWebSearchTool(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range payload.Tools {
|
||||
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func writeCloudUnauthorized(c *gin.Context) {
|
||||
signinURL, err := cloudProxySigninURL()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
||||
}
|
||||
|
||||
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
||||
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
challenge := buildCloudSignatureChallenge(req, ts)
|
||||
signature, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
||||
query := req.URL.Query()
|
||||
query.Set("ts", ts)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
||||
}
|
||||
|
||||
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
||||
baseURL = defaultCloudProxyBaseURL
|
||||
signingHost = defaultCloudProxySigningHost
|
||||
|
||||
rawOverride = strings.TrimSpace(rawOverride)
|
||||
if rawOverride == "" {
|
||||
return baseURL, signingHost, false, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawOverride)
|
||||
if err != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
||||
}
|
||||
if u.User != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
||||
}
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
||||
}
|
||||
if u.RawQuery != "" || u.Fragment != "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
||||
}
|
||||
|
||||
loopback := isLoopbackHost(host)
|
||||
if runMode == gin.ReleaseMode && !loopback {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
||||
}
|
||||
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
||||
}
|
||||
|
||||
u.Path = ""
|
||||
u.RawPath = ""
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
|
||||
return u.String(), strings.ToLower(host), true, nil
|
||||
}
|
||||
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func copyProxyRequestHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||
flusher, canFlush := dst.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
if canFlush {
|
||||
// TODO(drifkin): Consider conditional flushing so non-streaming
|
||||
// responses don't flush every write and can optimize throughput.
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(name string) bool {
|
||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
||||
tokens := map[string]struct{}{}
|
||||
for _, raw := range header.Values("Connection") {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
token = strings.TrimSpace(strings.ToLower(token))
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
tokens[token] = struct{}{}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := tokens[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
154
server/cloud_proxy_test.go
Normal file
154
server/cloud_proxy_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
||||
src.Add("X-Trace-Hop", "drop-me")
|
||||
src.Add("X-Alt-Hop", "drop-me-too")
|
||||
src.Add("Keep-Alive", "timeout=5")
|
||||
src.Add("X-End-To-End", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyRequestHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Keep-Alive"); got != "" {
|
||||
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Trace-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Alt-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
||||
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "X-Upstream-Hop")
|
||||
src.Add("X-Upstream-Hop", "drop-me")
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Server-Trace", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyResponseHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
||||
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if overridden {
|
||||
t.Fatal("expected override=false for empty input")
|
||||
}
|
||||
if baseURL != defaultCloudProxyBaseURL {
|
||||
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
||||
}
|
||||
if signingHost != defaultCloudProxySigningHost {
|
||||
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "localhost" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback override in release mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "https://example.com:8443" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "example.com" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback http override in dev mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
@@ -65,11 +65,22 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
|
||||
for v := range r.Files {
|
||||
for v, digest := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
return
|
||||
}
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, digest := range r.Adapters {
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
@@ -99,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromName := model.ParseName(r.From)
|
||||
if !fromName.IsValid() {
|
||||
fromRef, err := parseAndValidateModelRef(r.From)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
if r.RemoteHost != "" {
|
||||
ru, err := remoteURL(r.RemoteHost)
|
||||
|
||||
fromName := fromRef.Name
|
||||
remoteHost := r.RemoteHost
|
||||
if fromRef.Source == modelSourceCloud && remoteHost == "" {
|
||||
remoteHost = cloudProxyBaseURL
|
||||
}
|
||||
|
||||
if remoteHost != "" {
|
||||
ru, err := remoteURL(remoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
config.RemoteModel = r.From
|
||||
config.RemoteModel = fromRef.Base
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
|
||||
@@ -71,6 +71,10 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
func (m *Model) IsMLX() bool {
|
||||
return m.Config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
81
server/model_resolver.go
Normal file
81
server/model_resolver.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type modelSource = modelref.ModelSource
|
||||
|
||||
const (
|
||||
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
|
||||
modelSourceLocal modelSource = modelref.ModelSourceLocal
|
||||
modelSourceCloud modelSource = modelref.ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
|
||||
errModelRequired = modelref.ErrModelRequired
|
||||
)
|
||||
|
||||
type parsedModelRef struct {
|
||||
// Original is the caller-provided model string before source parsing.
|
||||
// Example: "gpt-oss:20b:cloud".
|
||||
Original string
|
||||
// Base is the model string after source suffix normalization.
|
||||
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
|
||||
Base string
|
||||
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
|
||||
// Example: "registry.ollama.ai/library/gpt-oss:20b".
|
||||
Name model.Name
|
||||
// Source captures explicit source intent from the original input.
|
||||
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
|
||||
Source modelSource
|
||||
}
|
||||
|
||||
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsed, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(parsed.Base)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsed.Original,
|
||||
Base: parsed.Base,
|
||||
Name: name,
|
||||
Source: parsed.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsedRef, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
normalizedName, _, err := modelref.NormalizePullName(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(normalizedName)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsedRef.Original,
|
||||
Base: normalizedName,
|
||||
Name: name,
|
||||
Source: parsedRef.Source,
|
||||
}, nil
|
||||
}
|
||||
170
server/model_resolver_test.go
Normal file
170
server/model_resolver_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseModelSelector(t *testing.T) {
|
||||
t.Run("cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("my-cloud-model")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "my-cloud-model" {
|
||||
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("qwen3:8b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "qwen3:8b" {
|
||||
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("conflicting source suffixes fail", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("foo:cloud:local")
|
||||
if !errors.Is(err, errConflictingModelSource) {
|
||||
t.Fatalf("expected errConflictingModelSource, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unspecified source", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("llama3")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "latest" {
|
||||
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:clod")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "clod" {
|
||||
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("")
|
||||
if !errors.Is(err, errModelRequired) {
|
||||
t.Fatalf("expected errModelRequired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("::cloud")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unqualified") {
|
||||
t.Fatalf("expected unqualified model error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParsePullModelRef(t *testing.T) {
|
||||
t.Run("explicit local is normalized", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "gpt-oss:20b-cloud" {
|
||||
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("qwen3:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "qwen3:cloud" {
|
||||
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
if truncate {
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -366,3 +367,33 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
||||
t.Fatal("prompt is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "extract text",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||
},
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: "glm-ocr"},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
159
server/routes.go
159
server/routes.go
@@ -64,6 +64,17 @@ const (
|
||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||
)
|
||||
|
||||
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
default:
|
||||
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
@@ -196,14 +207,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
// what the API currently returns until we can change it.
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
|
||||
// original body (modulo model name normalization) is sent to cloud.
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -237,6 +256,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||
@@ -484,7 +508,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -675,6 +700,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
@@ -697,7 +734,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
name, err := getExistingName(model.ParseName(req.Model))
|
||||
name, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -844,12 +881,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
@@ -891,12 +936,19 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
// TEMP(drifkin): we're temporarily allowing to continue pulling cloud model
|
||||
// stub-files until we integrate cloud models into `/api/tags` (in which case
|
||||
// this roundabout way of "adding" cloud models won't be needed anymore). So
|
||||
// right here normalize any `:cloud` models into the legacy-style suffixes
|
||||
// `:<tag>-cloud` and `:cloud`
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -1023,13 +1075,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !n.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
n, err := getExistingName(n)
|
||||
n, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
||||
return
|
||||
@@ -1078,6 +1137,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = modelRef.Base
|
||||
|
||||
resp, err := GetModelInfo(req)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
@@ -1094,6 +1167,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -1630,18 +1708,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1951,6 +2031,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
}
|
||||
if v.llama != nil {
|
||||
mr.ContextLength = v.llama.ContextLength()
|
||||
total, vram := v.llama.MemorySize()
|
||||
mr.Size = int64(total)
|
||||
mr.SizeVRAM = int64(vram)
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
@@ -1997,12 +2080,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
if c.GetBool(legacyCloudAnthropicKey) {
|
||||
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -2034,6 +2129,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
@@ -2213,6 +2313,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
if m.IsMLX() {
|
||||
truncate = false
|
||||
}
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -144,6 +144,37 @@ func TestCreateFromBin(t *testing.T) {
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||
})
|
||||
|
||||
t.Run("empty file digest", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "my-gguf-model",
|
||||
Files: map[string]string{"0.gguf": ""},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty adapter digest", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "my-gguf-model",
|
||||
Files: map[string]string{"0.gguf": digest},
|
||||
Adapters: map[string]string{"adapter.gguf": ""},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromModel(t *testing.T) {
|
||||
@@ -763,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
|
||||
fmt.Printf("resp = %#v\n", resp)
|
||||
}
|
||||
|
||||
func TestCreateFromCloudSourceSuffix(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud-from-suffix",
|
||||
From: "gpt-oss:20b:cloud",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.RemoteHost != "https://ollama.com:443" {
|
||||
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
|
||||
}
|
||||
|
||||
if resp.RemoteModel != "gpt-oss:20b" {
|
||||
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "gpt-oss:20b-cloud",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if pending.model.IsMLX() {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
@@ -536,6 +536,7 @@ iGPUScan:
|
||||
}
|
||||
}
|
||||
|
||||
totalSize, vramSize := llama.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -545,8 +546,8 @@ iGPUScan:
|
||||
sessionDuration: sessionDuration,
|
||||
gpus: gpuIDs,
|
||||
discreteGPUs: discreteGPUs,
|
||||
vramSize: llama.VRAMSize(),
|
||||
totalSize: llama.TotalSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
loading: true,
|
||||
pid: llama.Pid(),
|
||||
}
|
||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
totalSize, vramSize := server.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
loading: false,
|
||||
isImagegen: isImagegen,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
||||
s.closeCalled = true
|
||||
return s.closeResp
|
||||
}
|
||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
||||
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||
func (s *mockLlm) Pid() int { return -1 }
|
||||
func (s *mockLlm) GetPort() int { return -1 }
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -43,7 +44,7 @@ const (
|
||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||
// TODO: Improve local/cloud model identification - could check model metadata
|
||||
func isLocalModel(modelName string) bool {
|
||||
return !strings.HasSuffix(modelName, "-cloud")
|
||||
return !modelref.HasExplicitCloudSource(modelName)
|
||||
}
|
||||
|
||||
// isLocalServer checks if connecting to a local Ollama server.
|
||||
|
||||
@@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "cloud model",
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with :cloud suffix",
|
||||
modelName: "gpt-oss:cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version",
|
||||
modelName: "claude-3-cloud",
|
||||
modelName: "gpt-oss:20b-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version and :cloud suffix",
|
||||
modelName: "gpt-oss:20b:cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
||||
{
|
||||
name: "long output cloud model - uses 10k limit",
|
||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
@@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
||||
{
|
||||
name: "very long output cloud model - trimmed at 10k",
|
||||
output: string(defaultLimitOutput),
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
|
||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// VRAMSize returns the estimated VRAM usage.
|
||||
func (s *Server) VRAMSize() uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// TotalSize returns the total memory usage.
|
||||
func (s *Server) TotalSize() uint64 {
|
||||
return s.vramSize
|
||||
// MemorySize returns the total and VRAM memory usage.
|
||||
func (s *Server) MemorySize() (total, vram uint64) {
|
||||
return s.vramSize, s.vramSize
|
||||
}
|
||||
|
||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||
|
||||
@@ -78,6 +78,12 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
||||
prefix++
|
||||
}
|
||||
|
||||
// Always keep at least one token to re-evaluate so the
|
||||
// pipeline can seed token generation from it.
|
||||
if prefix == len(tokens) && prefix > 0 {
|
||||
prefix--
|
||||
}
|
||||
|
||||
if prefix < len(c.tokens) {
|
||||
trim := len(c.tokens) - prefix
|
||||
for _, kv := range c.caches {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -19,25 +18,27 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||
type Client struct {
|
||||
port int
|
||||
modelName string
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
contextLength atomic.Int64
|
||||
memory atomic.Uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
cmd: cmd,
|
||||
@@ -201,6 +193,19 @@ type completionOpts struct {
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason int
|
||||
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
@@ -260,28 +265,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
var raw CompletionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Error != nil {
|
||||
return *raw.Error
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
EvalDuration: raw.EvalDuration,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
@@ -294,7 +295,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
return math.MaxInt
|
||||
return int(c.contextLength.Load())
|
||||
}
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
@@ -347,9 +348,16 @@ func (c *Client) Pid() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
type statusResponse struct {
|
||||
Status int
|
||||
Progress int
|
||||
ContextLength int
|
||||
Memory uint64
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -362,6 +370,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var status statusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.contextLength.Store(int64(status.ContextLength))
|
||||
c.memory.Store(status.Memory)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -388,19 +405,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// TotalSize implements llm.LlamaServer.
|
||||
func (c *Client) TotalSize() uint64 {
|
||||
return c.vramSize
|
||||
func (c *Client) currentMemory() uint64 {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := c.Ping(ctx); err != nil {
|
||||
slog.Warn("failed to get current memory", "error", err)
|
||||
}
|
||||
return c.memory.Load()
|
||||
}
|
||||
|
||||
// MemorySize implements llm.LlamaServer.
|
||||
func (c *Client) MemorySize() (total, vram uint64) {
|
||||
mem := c.currentMemory()
|
||||
return mem, mem
|
||||
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
|
||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMSize implements llm.LlamaServer.
|
||||
func (c *Client) VRAMSize() uint64 {
|
||||
return c.vramSize
|
||||
return c.currentMemory()
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements llm.LlamaServer.
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned bool
|
||||
pinned int
|
||||
}
|
||||
|
||||
var arrays []*Array
|
||||
@@ -129,7 +129,7 @@ func (t *Array) Clone() *Array {
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned = true
|
||||
t.pinned++
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func Pin(s ...*Array) {
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned = false
|
||||
t.pinned--
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func Unpin(s ...*Array) {
|
||||
func Sweep() {
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned && t.Valid() {
|
||||
if t.pinned > 0 && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
@@ -175,7 +175,7 @@ func (t *Array) String() string {
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Bool("pinned", t.pinned),
|
||||
slog.Int("pinned", t.pinned),
|
||||
}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
|
||||
@@ -64,6 +64,10 @@ func PeakMemory() int {
|
||||
return int(peak)
|
||||
}
|
||||
|
||||
func ResetPeakMemory() {
|
||||
C.mlx_reset_peak_memory()
|
||||
}
|
||||
|
||||
type Memory struct{}
|
||||
|
||||
func (Memory) LogValue() slog.Value {
|
||||
|
||||
@@ -20,6 +20,7 @@ type Model interface {
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
MaxContextLength() int
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
@@ -18,6 +21,17 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
mlx.ResetPeakMemory()
|
||||
|
||||
var (
|
||||
sample, logprobs *mlx.Array
|
||||
nextSample, nextLogprobs *mlx.Array
|
||||
@@ -33,27 +47,36 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
mlx.LogArrays()
|
||||
r.cache.log()
|
||||
}
|
||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||
}()
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
if len(inputs) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
|
||||
if len(inputs) >= r.contextLength {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(inputs)
|
||||
if request.Options.MaxTokens <= 0 {
|
||||
request.Options.MaxTokens = maxGenerate
|
||||
} else {
|
||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||
}
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
|
||||
now := time.Now()
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
@@ -93,8 +116,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
@@ -103,9 +125,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
nextSample, nextLogprobs = step(sample)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
@@ -113,18 +134,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
final.EvalCount = i
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
case request.Responses <- CompletionResponse{
|
||||
Content: r.Decode(output, &b),
|
||||
}:
|
||||
}
|
||||
|
||||
@@ -137,7 +156,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
final.EvalDuration = time.Since(now)
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
|
||||
@@ -4,14 +4,15 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
Responses chan Response
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(Request) error
|
||||
|
||||
Ctx context.Context
|
||||
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Text string `json:"content,omitempty"`
|
||||
Token int `json:"token,omitempty"`
|
||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
|
||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
CompletionTokens int `json:"eval_count,omitempty"`
|
||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
contextLength int
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
statusErr = api.StatusError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
select {
|
||||
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||
case <-request.Ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
|
||||
@@ -50,9 +50,11 @@ func Execute(args []string) error {
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": 0,
|
||||
"progress": 100,
|
||||
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
ContextLength: runner.contextLength,
|
||||
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -78,7 +80,7 @@ func Execute(args []string) error {
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan Response)}
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
@@ -87,9 +89,6 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||
if request.Options.MaxTokens < 1 {
|
||||
request.Options.MaxTokens = 16 << 10
|
||||
}
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(
|
||||
|
||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user