Compare commits

..

3 Commits

Author SHA1 Message Date
Jesse Gross
c61023f554 ollamarunner: Fix off by one error with numPredict
When numPredict is set, the user will receive one less token
than the requested limit. In addition, the stats will incorrectly
show the number of tokens returned as the limit. In cases where
numPredict is not set, the number of tokens is reported correctly.

This occurs because numPredict is checked when setting up the next
batch but hitting the limit will terminate the current batch as well.
Instead, is is better to check the limit as we actually predict them.
2026-02-04 17:14:24 -08:00
Jeffrey Morgan
d25535c3f3 qwen3next: avoid inplace sigmoid for shared gate (#14077) 2026-02-04 15:50:02 -08:00
Bruce MacDonald
c323161f24 cmd: helpful error message for remote models (#14057)
When trying to use cloud model with OLLAMA_HOST="ollama.com" while not signed in a helpful error message is displayed when the user is not signed in telling them they must sign in to use cloud models. This should be the same experience for models which specify a remote instance.
2026-02-04 14:55:11 -08:00
10 changed files with 175 additions and 817 deletions

View File

@@ -367,14 +367,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
} else if info.RemoteHost != "" {
// Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models
if isCloud {
if _, err := client.Whoami(cmd.Context()); err != nil {
return err
}
}
if opts.ShowConnect {
p.StopAndClear()
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}

View File

@@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
t.Error("Copy Think should not be affected by original modification")
}
}
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct {
name string
remoteHost string
whoamiStatus int
whoamiResp any
expectedError string
}{
{
name: "ollama.com cloud model - user signed in",
remoteHost: "https://ollama.com",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "ollama.com cloud model - user not signed in",
remoteHost: "https://ollama.com",
whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
},
expectedError: "unauthorized",
},
{
name: "non-ollama.com remote - no auth check",
remoteHost: "https://other-remote.com",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
whoamiCalled := false
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost,
RemoteModel: "test-model",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
case "/api/me":
whoamiCalled = true
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.whoamiStatus)
if tt.whoamiResp != nil {
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
default:
http.NotFound(w, r)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
opts := &runOptions{
Model: "test-cloud-model",
ShowConnect: false,
}
err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}
} else {
if whoamiCalled {
t.Error("whoami should not be called for non-ollama.com remote")
}
}
if tt.expectedError != "" {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
}
}
} else {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
}
})
}
}

View File

@@ -48,7 +48,6 @@ var integrations = map[string]Runner{
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
"pi": &Pi{},
}
// recommendedModels are shown when the user has no models or as suggestions.
@@ -64,7 +63,6 @@ var recommendedModels = []selectItem{
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
func selectIntegration() (string, error) {

View File

@@ -1,196 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
for _, model := range models {
if selectedSet[model] {
newModels = append(newModels, map[string]any{
"id": model,
"_launch": true,
})
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0]
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
}
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}

View File

@@ -1,609 +0,0 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestPiIntegration(t *testing.T) {
pi := &Pi{}
t.Run("String", func(t *testing.T) {
if got := pi.String(); got != "Pi" {
t.Errorf("String() = %q, want %q", got, "Pi")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = pi
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = pi
})
}
func TestPiPaths(t *testing.T) {
pi := &Pi{}
t.Run("returns empty when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
paths := pi.Paths()
if len(paths) != 0 {
t.Errorf("Paths() = %v, want empty", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
t.Fatal(err)
}
paths := pi.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}
func TestPiEdit(t *testing.T) {
pi := &Pi{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
configPath := filepath.Join(configDir, "models.json")
cleanup := func() {
os.RemoveAll(configDir)
}
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
return cfg
}
t.Run("returns nil for empty models", func(t *testing.T) {
if err := pi.Edit([]string{}); err != nil {
t.Errorf("Edit([]) error = %v, want nil", err)
}
})
t.Run("creates config with models", func(t *testing.T) {
cleanup()
models := []string{"llama3.2", "qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers, ok := cfg["providers"].(map[string]any)
if !ok {
t.Error("Config missing providers")
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
t.Error("Providers missing ollama")
}
modelsArray, ok := ollama["models"].([]any)
if !ok || len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %v", modelsArray)
}
if ollama["baseUrl"] == nil {
t.Error("Missing baseUrl")
}
if ollama["api"] != "openai-completions" {
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
}
if ollama["apiKey"] != "ollama" {
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
}
})
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://custom:8080/v1",
"api": "custom-api",
"apiKey": "custom-key",
"models": [
{"id": "old-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"new-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
if ollama["baseUrl"] != "http://custom:8080/v1" {
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
}
if ollama["api"] != "custom-api" {
t.Errorf("Custom api not preserved, got %v", ollama["api"])
}
if ollama["apiKey"] != "custom-key" {
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
}
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
} else {
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["id"] != "new-model" {
t.Errorf("Expected new-model, got %v", modelEntry["id"])
}
// Verify _launch marker is present
if modelEntry["_launch"] != true {
t.Errorf("Expected _launch marker to be true")
}
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Old models must have _launch marker to be managed by us
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "old-model-1", "_launch": true},
{"id": "old-model-2", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"new-model-1", "new-model-2"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
t.Errorf("Expected new models, got %v", modelIDs)
}
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
t.Errorf("Old models should have been removed, got %v", modelIDs)
}
})
t.Run("handles partial overlap in model list", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Models must have _launch marker to be managed
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "keep-model", "_launch": true},
{"id": "remove-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"keep-model", "add-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
}
if modelIDs["remove-model"] {
t.Errorf("remove-model should have been removed")
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model, got %d", len(modelsArray))
}
})
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// User has manually configured models in ollama provider (no _launch marker)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "user-model-1"},
{"id": "user-model-2", "customField": "preserved"},
{"id": "ollama-managed", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
// Add a new ollama-managed model
newModels := []string{"new-ollama-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
// Should have: new-ollama-model (managed) + 2 user models (preserved)
if len(modelsArray) != 3 {
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
}
modelIDs := make(map[string]map[string]any)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = modelObj
}
// Verify new model has _launch marker
if m, ok := modelIDs["new-ollama-model"]; !ok {
t.Errorf("new-ollama-model should be present")
} else if m["_launch"] != true {
t.Errorf("new-ollama-model should have _launch marker")
}
// Verify user models are preserved
if _, ok := modelIDs["user-model-1"]; !ok {
t.Errorf("user-model-1 should be preserved")
}
if _, ok := modelIDs["user-model-2"]; !ok {
t.Errorf("user-model-2 should be preserved")
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
t.Errorf("user-model-2 customField should be preserved")
}
// Verify old ollama-managed model is removed (not in new list)
if _, ok := modelIDs["ollama-managed"]; ok {
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
}
})
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create existing settings with other fields
settingsPath := filepath.Join(configDir, "settings.json")
existingSettings := `{
"theme": "dark",
"customSetting": "value",
"defaultProvider": "anthropic",
"defaultModel": "claude-3"
}`
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"llama3.2"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
// Verify defaultProvider is set to ollama
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
// Verify defaultModel is set to first model
if settings["defaultModel"] != "llama3.2" {
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
}
// Verify other fields are preserved
if settings["theme"] != "dark" {
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
}
if settings["customSetting"] != "value" {
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
}
})
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
models := []string{"qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
settingsPath := filepath.Join(configDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("settings.json should be created: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "qwen3:8b" {
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
}
})
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create corrupt settings
settingsPath := filepath.Join(configDir, "settings.json")
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "test-model" {
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
}
})
}
func TestPiModels(t *testing.T) {
pi := &Pi{}
t.Run("returns nil when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns models from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "llama3.2"},
{"id": "qwen3:8b"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if len(models) != 2 {
t.Errorf("Models() returned %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
}
})
t.Run("returns sorted models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "z-model"},
{"id": "a-model"},
{"id": "m-model"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
}
})
t.Run("returns nil when models array is missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil when models array is missing", models)
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil for corrupt config", models)
}
})
}

View File

@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
// TestNumPredict verifies that when num_predict is set, the model generates
// exactly that many tokens. It uses logprobs to count the actual tokens output.
func TestNumPredict(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
t.Fatalf("failed to pull model: %v", err)
}
req := api.GenerateRequest{
Model: "qwen3:0.6b",
Prompt: "Write a long story.",
Stream: &stream,
Logprobs: true,
Options: map[string]any{
"num_predict": 10,
"temperature": 0,
"seed": 123,
},
}
logprobCount := 0
var finalResponse api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
logprobCount += len(resp.Logprobs)
if resp.Done {
finalResponse = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %v", err)
}
if logprobCount != 10 {
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
}
}

View File

@@ -175,6 +175,7 @@ type Tensor interface {
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
SigmoidOut(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

View File

@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:

View File

@@ -135,7 +135,7 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
// Apply shared expert gating
if mlp.SharedGateInp != nil {
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
sharedGateVal = sharedGateVal.Sigmoid(ctx)
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
// Broadcast gate to match dimensions
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
sharedOut = sharedOut.Mul(ctx, sharedGateVal)

View File

@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
seq.pendingResponses = append(seq.pendingResponses, piece)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, llm.DoneReasonLength)
continue
}
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := common.FindStop(sequence, seq.stop); ok {