Compare commits

..

1 Commits

Author SHA1 Message Date
ParthSareen
3d140a1f9f cmd/config: offer to pull missing models instead of erroring 2026-02-05 22:51:56 -08:00
8 changed files with 177 additions and 132 deletions

View File

@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"encoding/json"
"fmt"
"os"
@@ -9,7 +8,6 @@ import (
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
@@ -114,17 +112,9 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}
}
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
@@ -132,7 +122,7 @@ func (d *Droid) Edit(models []string) error {
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: maxOutput,
MaxOutputTokens: 64000,
SupportsImages: false,
ID: modelID,
Index: i,

View File

@@ -1251,55 +1251,6 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
}
}
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
if err := d.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
models := settings["customModels"].([]any)
entry := models[0].(map[string]any)
if entry["maxOutputTokens"] != float64(64000) {
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
}
}
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// :cloud suffix stripping must also work since that's how users specify them.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
}
})
}
}
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()

View File

@@ -194,6 +194,20 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
return nil
}
// showOrPull checks if a model exists via client.Show and offers to pull it if not found.
func showOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullModel(ctx, client, model)
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -397,8 +411,11 @@ Examples:
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
@@ -424,9 +441,11 @@ Examples:
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
if err := showOrPull(cmd.Context(), client, model); err != nil {
model = ""
}
}
}
@@ -443,6 +462,13 @@ Examples:
existingAliases = aliases
}
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
return err
}
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
@@ -467,8 +493,11 @@ Examples:
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
@@ -650,7 +679,7 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return false
}

View File

@@ -2,12 +2,17 @@ package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
@@ -539,3 +544,136 @@ func TestAliasConfigurerInterface(t *testing.T) {
}
})
}
func TestShowOrPull_ModelExists(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"test-model"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := showOrPull(context.Background(), client, "test-model")
if err != nil {
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
}
}
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
err := showOrPull(context.Background(), client, "missing-model")
if err == nil {
t.Error("showOrPull should return error when model not found and no terminal available")
}
}
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
var receivedModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
receivedModel = req.Model
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
_ = showOrPull(context.Background(), client, "qwen3:8b")
if receivedModel != "qwen3:8b" {
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
}
}
func TestEnsureAuth_NoCloudModels(t *testing.T) {
// ensureAuth should be a no-op when no cloud models are selected
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("no API calls expected when no cloud models selected")
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
if err != nil {
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
}
}
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
// ensureAuth should only care about models in cloudModels map
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"name":"testuser"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"cloud-model:cloud", "local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
}
if !whoamiCalled {
t.Error("expected whoami to be called for cloud model")
}
}
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// cloudModels has entries but none are in selected
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if whoamiCalled {
t.Error("whoami should not be called when no cloud models are selected")
}
}

View File

@@ -39,7 +39,6 @@ var cloudModelLimits = map[string]cloudModelLimit{
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}

View File

@@ -633,7 +633,6 @@ func TestLookupCloudModelLimit(t *testing.T) {
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}

View File

@@ -1,61 +0,0 @@
include(FetchContent)
# Read MLX version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)
function(set_target_output_directory _target)
if(TARGET ${_target})
set_target_properties(${_target} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
)
endif()
endfunction()
# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(NOT MLX_METAL_VERSION)
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
else()
# On Linux, disable Metal backend
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG})
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)
set_target_output_directory(mlxc)