Mirror of https://github.com/ollama/ollama.git (synced 2026-02-05 21:23:43 -05:00)

Compare commits: 7 commits, mxyng/mlx-... ... pdevine/la...
| Author | SHA1 | Date |
|---|---|---|
| | c52a977500 | |
| | b4023686ae | |
| | afa6c9f2da | |
| | b09df8b79c | |
| | 82f4699516 | |
| | 9f751ea89d | |
| | c16ef7a1af | |
22  .github/workflows/test-install.yaml (vendored)
@@ -1,22 +0,0 @@
-name: test-install
-
-on:
-  pull_request:
-    paths:
-      - 'scripts/install.sh'
-      - '.github/workflows/test-install.yaml'
-
-jobs:
-  test:
-    strategy:
-      matrix:
-        os: [ubuntu-latest, macos-latest]
-    runs-on: ${{ matrix.os }}
-    steps:
-      - uses: actions/checkout@v4
-      - name: Run install script
-        run: sh ./scripts/install.sh
-        env:
-          OLLAMA_NO_START: 1 # do not start app
-      - name: Verify ollama is available
-        run: ollama --version
@@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string {
 
 // StreamConverter manages state for converting Ollama streaming responses to Anthropic format
 type StreamConverter struct {
-    ID                   string
-    Model                string
-    firstWrite           bool
-    contentIndex         int
-    inputTokens          int
-    outputTokens         int
-    estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
-    thinkingStarted      bool
-    thinkingDone         bool
-    textStarted          bool
-    toolCallsSent        map[string]bool
+    ID              string
+    Model           string
+    firstWrite      bool
+    contentIndex    int
+    inputTokens     int
+    outputTokens    int
+    thinkingStarted bool
+    thinkingDone    bool
+    textStarted     bool
+    toolCallsSent   map[string]bool
 }
 
-func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
+func NewStreamConverter(id, model string) *StreamConverter {
     return &StreamConverter{
-        ID:                   id,
-        Model:                model,
-        firstWrite:           true,
-        estimatedInputTokens: estimatedInputTokens,
-        toolCallsSent:        make(map[string]bool),
+        ID:            id,
+        Model:         model,
+        firstWrite:    true,
+        toolCallsSent: make(map[string]bool),
     }
 }
 
@@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
 
     if c.firstWrite {
         c.firstWrite = false
-        // Use actual metrics if available, otherwise use estimate
         c.inputTokens = r.Metrics.PromptEvalCount
-        if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
-            c.inputTokens = c.estimatedInputTokens
-        }
 
         events = append(events, StreamEvent{
             Event: "message_start",
@@ -785,123 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
     }
     return args
 }
-
-// CountTokensRequest represents an Anthropic count_tokens request
-type CountTokensRequest struct {
-    Model    string          `json:"model"`
-    Messages []MessageParam  `json:"messages"`
-    System   any             `json:"system,omitempty"`
-    Tools    []Tool          `json:"tools,omitempty"`
-    Thinking *ThinkingConfig `json:"thinking,omitempty"`
-}
-
-// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
-func EstimateInputTokens(req MessagesRequest) int {
-    return estimateTokens(CountTokensRequest{
-        Model:    req.Model,
-        Messages: req.Messages,
-        System:   req.System,
-        Tools:    req.Tools,
-        Thinking: req.Thinking,
-    })
-}
-
-// CountTokensResponse represents an Anthropic count_tokens response
-type CountTokensResponse struct {
-    InputTokens int `json:"input_tokens"`
-}
-
-// estimateTokens returns a rough estimate of tokens (len/4).
-// TODO: Replace with actual tokenization via Tokenize API for accuracy.
-// Current len/4 heuristic is a rough approximation (~4 chars/token average).
-func estimateTokens(req CountTokensRequest) int {
-    var totalLen int
-
-    // Count system prompt
-    if req.System != nil {
-        totalLen += countAnyContent(req.System)
-    }
-
-    // Count messages
-    for _, msg := range req.Messages {
-        // Count role (always present)
-        totalLen += len(msg.Role)
-        // Count content
-        contentLen := countAnyContent(msg.Content)
-        totalLen += contentLen
-    }
-
-    for _, tool := range req.Tools {
-        totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
-    }
-
-    // Return len/4 as rough token estimate, minimum 1 if there's any content
-    tokens := totalLen / 4
-    if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
-        tokens = 1
-    }
-    return tokens
-}
-
-func countAnyContent(content any) int {
-    if content == nil {
-        return 0
-    }
-
-    switch c := content.(type) {
-    case string:
-        return len(c)
-    case []any:
-        total := 0
-        for _, block := range c {
-            total += countContentBlock(block)
-        }
-        return total
-    default:
-        if data, err := json.Marshal(content); err == nil {
-            return len(data)
-        }
-        return 0
-    }
-}
-
-func countContentBlock(block any) int {
-    blockMap, ok := block.(map[string]any)
-    if !ok {
-        if s, ok := block.(string); ok {
-            return len(s)
-        }
-        return 0
-    }
-
-    total := 0
-    blockType, _ := blockMap["type"].(string)
-
-    if text, ok := blockMap["text"].(string); ok {
-        total += len(text)
-    }
-
-    if thinking, ok := blockMap["thinking"].(string); ok {
-        total += len(thinking)
-    }
-
-    if blockType == "tool_use" {
-        if data, err := json.Marshal(blockMap); err == nil {
-            total += len(data)
-        }
-    }
-
-    if blockType == "tool_result" {
-        if data, err := json.Marshal(blockMap); err == nil {
-            total += len(data)
-        }
-    }
-
-    if source, ok := blockMap["source"].(map[string]any); ok {
-        if data, ok := source["data"].(string); ok {
-            total += len(data)
-        }
-    }
-
-    return total
-}
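The removed estimator above leans entirely on the len/4 rule of thumb noted in its own TODO. A minimal, self-contained sketch of that heuristic in isolation (the sample strings come from the tests further down; nothing here is Ollama API):

```go
package main

import "fmt"

// estimateByLength reproduces the removed ~4-chars-per-token heuristic:
// integer length divided by four, with a floor of 1 for non-empty input.
func estimateByLength(s string) int {
	tokens := len(s) / 4
	if tokens == 0 && len(s) > 0 {
		tokens = 1 // minimum 1 token for any non-empty content
	}
	return tokens
}

func main() {
	for _, s := range []string{"Hello, world!", "You are a helpful assistant."} {
		fmt.Printf("%q -> ~%d tokens (len %d / 4)\n", s, estimateByLength(s), len(s))
	}
}
```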
@@ -321,6 +321,8 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
     }
 }
 
+// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
+// a thinking block (no text, images, or tool calls) are preserved and not dropped.
 func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
     req := MessagesRequest{
         Model: "test-model",
@@ -603,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
 }
 
 func TestStreamConverter_Basic(t *testing.T) {
-    conv := NewStreamConverter("msg_123", "test-model", 0)
+    conv := NewStreamConverter("msg_123", "test-model")
 
     // First chunk
     resp1 := api.ChatResponse{
@@ -676,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
 }
 
 func TestStreamConverter_WithToolCalls(t *testing.T) {
-    conv := NewStreamConverter("msg_123", "test-model", 0)
+    conv := NewStreamConverter("msg_123", "test-model")
 
     resp := api.ChatResponse{
         Model: "test-model",
@@ -729,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
 func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
     // Test that unmarshalable arguments (like channels) are handled gracefully
     // and don't cause a panic or corrupt stream
-    conv := NewStreamConverter("msg_123", "test-model", 0)
+    conv := NewStreamConverter("msg_123", "test-model")
 
     // Create a channel which cannot be JSON marshaled
     unmarshalable := make(chan int)
@@ -776,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
 
 func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
     // Test that valid tool calls still work when mixed with invalid ones
-    conv := NewStreamConverter("msg_123", "test-model", 0)
+    conv := NewStreamConverter("msg_123", "test-model")
 
     unmarshalable := make(chan int)
     badArgs := api.NewToolCallFunctionArguments()
@@ -840,6 +842,10 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
     }
 }
 
+// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
+// are serialized in JSON output. The Anthropic SDK requires these fields to be present
+// (even when empty) in content_block_start events to properly accumulate streaming deltas.
+// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
 func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
     tests := []struct {
         name string
@@ -893,9 +899,11 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
     }
 }
 
+// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
+// events include the required empty fields for SDK compatibility.
 func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
     t.Run("text block start includes empty text", func(t *testing.T) {
-        conv := NewStreamConverter("msg_123", "test-model", 0)
+        conv := NewStreamConverter("msg_123", "test-model")
 
         resp := api.ChatResponse{
             Model: "test-model",
@@ -929,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
     })
 
     t.Run("thinking block start includes empty thinking", func(t *testing.T) {
-        conv := NewStreamConverter("msg_123", "test-model", 0)
+        conv := NewStreamConverter("msg_123", "test-model")
 
         resp := api.ChatResponse{
             Model: "test-model",
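The new test comments above pin down a wire-format requirement: a content_block_start event must carry an explicitly empty text (or thinking) field, not omit it. A small sketch of the JSON shape this implies; the contentBlock struct here is illustrative, not ollama's actual type:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative shape only; field names follow the Anthropic streaming events
// the tests above describe. Note the deliberate absence of `omitempty` on
// Text, so an empty string still appears in the serialized event.
type contentBlock struct {
	Type string `json:"type"`
	Text string `json:"text"` // no omitempty: the SDK needs "" to be present
}

func main() {
	b, _ := json.Marshal(map[string]any{
		"type":          "content_block_start",
		"index":         0,
		"content_block": contentBlock{Type: "text", Text: ""},
	})
	// Map keys are marshaled in sorted order:
	// {"content_block":{"type":"text","text":""},"index":0,"type":"content_block_start"}
	fmt.Println(string(b))
}
```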
@@ -961,105 +969,3 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
         }
     })
 }
-
-func TestEstimateTokens_SimpleMessage(t *testing.T) {
-    req := CountTokensRequest{
-        Model: "test-model",
-        Messages: []MessageParam{
-            {Role: "user", Content: "Hello, world!"},
-        },
-    }
-
-    tokens := estimateTokens(req)
-
-    // "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
-    if tokens < 1 {
-        t.Errorf("expected at least 1 token, got %d", tokens)
-    }
-    // Sanity check: shouldn't be wildly off
-    if tokens > 10 {
-        t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
-    }
-}
-
-func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
-    req := CountTokensRequest{
-        Model:  "test-model",
-        System: "You are a helpful assistant.",
-        Messages: []MessageParam{
-            {Role: "user", Content: "Hello"},
-        },
-    }
-
-    tokens := estimateTokens(req)
-
-    // System prompt adds to count
-    if tokens < 5 {
-        t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
-    }
-}
-
-func TestEstimateTokens_WithTools(t *testing.T) {
-    req := CountTokensRequest{
-        Model: "test-model",
-        Messages: []MessageParam{
-            {Role: "user", Content: "What's the weather?"},
-        },
-        Tools: []Tool{
-            {
-                Name:        "get_weather",
-                Description: "Get the current weather for a location",
-                InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
-            },
-        },
-    }
-
-    tokens := estimateTokens(req)
-
-    // Tools add significant content
-    if tokens < 10 {
-        t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
-    }
-}
-
-func TestEstimateTokens_WithThinking(t *testing.T) {
-    req := CountTokensRequest{
-        Model: "test-model",
-        Messages: []MessageParam{
-            {Role: "user", Content: "Hello"},
-            {
-                Role: "assistant",
-                Content: []any{
-                    map[string]any{
-                        "type":     "thinking",
-                        "thinking": "Let me think about this carefully...",
-                    },
-                    map[string]any{
-                        "type": "text",
-                        "text": "Here is my response.",
-                    },
-                },
-            },
-        },
-    }
-
-    tokens := estimateTokens(req)
-
-    // Thinking content should be counted
-    if tokens < 10 {
-        t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
-    }
-}
-
-func TestEstimateTokens_EmptyContent(t *testing.T) {
-    req := CountTokensRequest{
-        Model:    "test-model",
-        Messages: []MessageParam{},
-    }
-
-    tokens := estimateTokens(req)
-
-    if tokens != 0 {
-        t.Errorf("expected 0 tokens for empty content, got %d", tokens)
-    }
-}
@@ -466,25 +466,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
     }
     return &resp, nil
 }
-
-// AliasRequest is the request body for creating or updating a model alias.
-type AliasRequest struct {
-    Alias          string `json:"alias"`
-    Target         string `json:"target"`
-    PrefixMatching bool   `json:"prefix_matching,omitempty"`
-}
-
-// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
-func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
-    return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
-}
-
-// AliasDeleteRequest is the request body for deleting a model alias.
-type AliasDeleteRequest struct {
-    Alias string `json:"alias"`
-}
-
-// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
-func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
-    return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
-}
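For reference, the request body implied by the struct tags of the removed AliasRequest; a sketch only, since the endpoint is experimental and is dropped entirely by this compare:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the removed AliasRequest struct tags; the alias/target values are
// placeholders, and the endpoint semantics are as described in the removed
// doc comments above.
type aliasRequest struct {
	Alias          string `json:"alias"`
	Target         string `json:"target"`
	PrefixMatching bool   `json:"prefix_matching,omitempty"`
}

func main() {
	body, _ := json.MarshalIndent(aliasRequest{
		Alias:          "claude-sonnet-",
		Target:         "my-cloud-model",
		PrefixMatching: true,
	}, "", "  ")
	// SetAliasExperimental POSTed this to /api/experimental/aliases.
	fmt.Println(string(body))
}
```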
13  cmd/background_unix.go (new file)
@@ -0,0 +1,13 @@
+//go:build !windows
+
+package cmd
+
+import "syscall"
+
+// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
+// Setpgid prevents the server from being killed when the parent process exits.
+func backgroundServerSysProcAttr() *syscall.SysProcAttr {
+    return &syscall.SysProcAttr{
+        Setpgid: true,
+    }
+}
12  cmd/background_windows.go (new file)
@@ -0,0 +1,12 @@
+package cmd
+
+import "syscall"
+
+// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
+// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
+func backgroundServerSysProcAttr() *syscall.SysProcAttr {
+    return &syscall.SysProcAttr{
+        CreationFlags: 0x08000000,
+        HideWindow:    true,
+    }
+}
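Both new files exist to feed exec.Cmd.SysProcAttr with a platform-appropriate value. A minimal Unix-only sketch of the pattern they enable; the "ollama serve" argv mirrors what ensureServerRunning in cmd/cmd.go below actually runs:

```go
//go:build !windows

package main

import (
	"os"
	"os/exec"
	"syscall"
)

func main() {
	// Start the server as a detached child that survives parent exit,
	// the same pattern cmd/cmd.go uses via backgroundServerSysProcAttr.
	cmd := exec.Command("ollama", "serve")
	cmd.Env = os.Environ()
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} // own process group
	if err := cmd.Start(); err != nil {
		panic(err)
	}
	// The parent may exit now; the child keeps running in its own
	// process group and is not killed along with the parent.
}
```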
194  cmd/cmd.go
@@ -15,6 +15,7 @@ import (
     "net"
     "net/http"
     "os"
+    "os/exec"
     "os/signal"
     "path/filepath"
     "runtime"
@@ -37,6 +38,7 @@
 
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/cmd/config"
+    "github.com/ollama/ollama/cmd/tui"
     "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/format"
     "github.com/ollama/ollama/parser"
@@ -1763,7 +1765,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
             return err
         }
         if err := startApp(cmd.Context(), client); err != nil {
-            return err
+            return fmt.Errorf("ollama server not responding - %w", err)
         }
     }
     return nil
@@ -1804,6 +1806,190 @@ Environment Variables:
     cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
 }
 
+// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
+func ensureServerRunning(ctx context.Context) error {
+    client, err := api.ClientFromEnvironment()
+    if err != nil {
+        return err
+    }
+
+    // Check if server is already running
+    if err := client.Heartbeat(ctx); err == nil {
+        return nil // server is already running
+    }
+
+    // Server not running, start it in the background
+    exe, err := os.Executable()
+    if err != nil {
+        return fmt.Errorf("could not find executable: %w", err)
+    }
+
+    serverCmd := exec.CommandContext(ctx, exe, "serve")
+    serverCmd.Env = os.Environ()
+    serverCmd.SysProcAttr = backgroundServerSysProcAttr()
+    if err := serverCmd.Start(); err != nil {
+        return fmt.Errorf("failed to start server: %w", err)
+    }
+
+    // Wait for the server to be ready
+    for {
+        time.Sleep(500 * time.Millisecond)
+        if err := client.Heartbeat(ctx); err == nil {
+            return nil // server has started
+        }
+    }
+}
+
+// runInteractiveTUI runs the main interactive TUI menu.
+func runInteractiveTUI(cmd *cobra.Command) {
+    // Ensure the server is running before showing the TUI
+    if err := ensureServerRunning(cmd.Context()); err != nil {
+        fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
+        return
+    }
+
+    // errSelectionCancelled is returned when user cancels model selection
+    errSelectionCancelled := errors.New("cancelled")
+
+    // Selector adapters for tui
+    singleSelector := func(title string, items []config.ModelItem) (string, error) {
+        tuiItems := make([]tui.SelectItem, len(items))
+        for i, item := range items {
+            tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
+        }
+        result, err := tui.SelectSingle(title, tuiItems)
+        if errors.Is(err, tui.ErrCancelled) {
+            return "", errSelectionCancelled
+        }
+        return result, err
+    }
+
+    multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
+        tuiItems := make([]tui.SelectItem, len(items))
+        for i, item := range items {
+            tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
+        }
+        result, err := tui.SelectMultiple(title, tuiItems, preChecked)
+        if errors.Is(err, tui.ErrCancelled) {
+            return nil, errSelectionCancelled
+        }
+        return result, err
+    }
+
+    for {
+        result, err := tui.Run()
+        if err != nil {
+            fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+            return
+        }
+
+        runModel := func(modelName string) {
+            _ = config.SetLastModel(modelName)
+            opts := runOptions{
+                Model:       modelName,
+                WordWrap:    os.Getenv("TERM") == "xterm-256color",
+                Options:     map[string]any{},
+                ShowConnect: true,
+            }
+            if err := loadOrUnloadModel(cmd, &opts); err != nil {
+                fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
+                return
+            }
+            if err := generateInteractive(cmd, opts); err != nil {
+                fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
+            }
+        }
+
+        launchIntegration := func(name string) bool {
+            // If not configured or model no longer exists, prompt for model selection
+            configuredModel := config.IntegrationModel(name)
+            if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
+                err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
+                if errors.Is(err, errSelectionCancelled) {
+                    return false // Return to main menu
+                }
+                if err != nil {
+                    fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
+                    return true
+                }
+            }
+            if err := config.LaunchIntegration(name); err != nil {
+                fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
+            }
+            return true
+        }
+
+        switch result.Selection {
+        case tui.SelectionNone:
+            // User quit
+            return
+        case tui.SelectionRunModel:
+            _ = config.SetLastSelection("run")
+            // Run last model directly if configured and still exists
+            if modelName := config.LastModel(); modelName != "" && config.ModelExists(cmd.Context(), modelName) {
+                runModel(modelName)
+            } else {
+                // No last model or model no longer exists, show picker
+                modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
+                if errors.Is(err, errSelectionCancelled) {
+                    continue // Return to main menu
+                }
+                if err != nil {
+                    fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
+                    continue
+                }
+                runModel(modelName)
+            }
+        case tui.SelectionChangeRunModel:
+            _ = config.SetLastSelection("run")
+            // Use model from modal if selected, otherwise show picker
+            modelName := result.Model
+            if modelName == "" {
+                var err error
+                modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
+                if errors.Is(err, errSelectionCancelled) {
+                    continue // Return to main menu
+                }
+                if err != nil {
+                    fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
+                    continue
+                }
+            }
+            runModel(modelName)
+        case tui.SelectionIntegration:
+            _ = config.SetLastSelection(result.Integration)
+            if !launchIntegration(result.Integration) {
+                continue // Return to main menu
+            }
+        case tui.SelectionChangeIntegration:
+            _ = config.SetLastSelection(result.Integration)
+            // Use model from modal if selected, otherwise show picker
+            if result.Model != "" {
+                // Model already selected from modal - save and launch
+                if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
+                    fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
+                    continue
+                }
+                if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
+                    fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
+                }
+            } else {
+                err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
+                if errors.Is(err, errSelectionCancelled) {
+                    continue // Return to main menu
+                }
+                if err != nil {
+                    fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
+                    continue
+                }
+                if err := config.LaunchIntegration(result.Integration); err != nil {
+                    fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
+                }
+            }
+        }
+    }
+}
+
 func NewCLI() *cobra.Command {
     log.SetFlags(log.LstdFlags | log.Lshortfile)
     cobra.EnableCommandSorting = false
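One design note on the hunk above: the readiness loop in ensureServerRunning polls every 500ms with no upper bound, so a server that never comes up means the caller waits indefinitely. A hypothetical bounded variant, not part of this diff, might look like:

```go
package cmd

import (
	"context"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

// waitForHeartbeat is a hypothetical bounded alternative to the unbounded
// polling loop in ensureServerRunning: it gives up after `timeout` and
// also honors context cancellation between polls.
func waitForHeartbeat(ctx context.Context, client *api.Client, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if err := client.Heartbeat(ctx); err == nil {
			return nil // server is ready
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(500 * time.Millisecond):
			// poll again
		}
	}
	return fmt.Errorf("server did not become ready within %s", timeout)
}
```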
@@ -1826,11 +2012,13 @@ func NewCLI() *cobra.Command {
                 return
             }
 
-            cmd.Print(cmd.UsageString())
+            runInteractiveTUI(cmd)
         },
     }
 
     rootCmd.Flags().BoolP("version", "v", false, "Show version information")
+    rootCmd.Flags().Bool("verbose", false, "Show timings for response")
+    rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 
     createCmd := &cobra.Command{
         Use: "create MODEL",
@@ -2044,7 +2232,7 @@ func NewCLI() *cobra.Command {
         copyCmd,
         deleteCmd,
         runnerCmd,
-        config.LaunchCmd(checkServerHeartbeat),
+        config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
     )
 
     return rootCmd
@@ -1,23 +1,18 @@
 package config
 
 import (
-    "context"
     "fmt"
     "os"
    "os/exec"
     "path/filepath"
     "runtime"
 
-    "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/envconfig"
 )
 
-// Claude implements Runner and AliasConfigurer for Claude Code integration
+// Claude implements Runner for Claude Code integration
 type Claude struct{}
 
-// Compile-time check that Claude implements AliasConfigurer
-var _ AliasConfigurer = (*Claude)(nil)
-
 func (c *Claude) String() string { return "Claude Code" }
 
 func (c *Claude) args(model string, extra []string) []string {
@@ -65,104 +60,3 @@ func (c *Claude) Run(model string, args []string) error {
     )
     return cmd.Run()
 }
-
-// ConfigureAliases sets up model aliases for Claude Code.
-// model: the model to use (if empty, user will be prompted to select)
-// aliases: existing alias configuration to preserve/update
-// Cloud-only: subagent routing (fast model) is gated to cloud models only until
-// there is a better strategy for prompt caching on local models.
-func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
-    aliases := make(map[string]string)
-    for k, v := range existingAliases {
-        aliases[k] = v
-    }
-
-    if model != "" {
-        aliases["primary"] = model
-    }
-
-    if !force && aliases["primary"] != "" {
-        client, _ := api.ClientFromEnvironment()
-        if isCloudModel(ctx, client, aliases["primary"]) {
-            if isCloudModel(ctx, client, aliases["fast"]) {
-                return aliases, false, nil
-            }
-        } else {
-            delete(aliases, "fast")
-            return aliases, false, nil
-        }
-    }
-
-    items, existingModels, cloudModels, client, err := listModels(ctx)
-    if err != nil {
-        return nil, false, err
-    }
-
-    fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
-
-    if aliases["primary"] == "" || force {
-        primary, err := selectPrompt("Select model:", items)
-        fmt.Fprintf(os.Stderr, "\033[3A\033[J")
-        if err != nil {
-            return nil, false, err
-        }
-        if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
-            return nil, false, err
-        }
-        if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
-            return nil, false, err
-        }
-        aliases["primary"] = primary
-    }
-
-    if isCloudModel(ctx, client, aliases["primary"]) {
-        if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
-            aliases["fast"] = aliases["primary"]
-        }
-    } else {
-        delete(aliases, "fast")
-    }
-
-    return aliases, true, nil
-}
-
-// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
-// Cloud-only: for local models (fast is empty), we delete any existing aliases to
-// prevent stale routing to a previous cloud model.
-func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
-    client, err := api.ClientFromEnvironment()
-    if err != nil {
-        return err
-    }
-
-    prefixes := []string{"claude-sonnet-", "claude-haiku-"}
-
-    if aliases["fast"] == "" {
-        for _, prefix := range prefixes {
-            _ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
-        }
-        return nil
-    }
-
-    prefixAliases := map[string]string{
-        "claude-sonnet-": aliases["primary"],
-        "claude-haiku-":  aliases["fast"],
-    }
-
-    var errs []string
-    for prefix, target := range prefixAliases {
-        req := &api.AliasRequest{
-            Alias:          prefix,
-            Target:         target,
-            PrefixMatching: true,
-        }
-        if err := client.SetAliasExperimental(ctx, req); err != nil {
-            errs = append(errs, prefix)
-        }
-    }
-
-    if len(errs) > 0 {
-        return fmt.Errorf("failed to set aliases: %v", errs)
-    }
-    return nil
-}
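The removed SetAliases routes the claude-sonnet- and claude-haiku- prefixes to the primary and fast models. A client-side sketch of what that prefix resolution amounts to; the real matching happens server-side via the experimental aliases API, and the sample model names are illustrative:

```go
package main

import (
	"fmt"
	"strings"
)

// resolveAlias returns the target for the first registered prefix that
// matches the requested model name, or the name unchanged if none match.
// Illustrative only: this mimics, not reimplements, the server behavior.
func resolveAlias(requested string, prefixAliases map[string]string) (string, bool) {
	for prefix, target := range prefixAliases {
		if strings.HasPrefix(requested, prefix) {
			return target, true
		}
	}
	return requested, false
}

func main() {
	aliases := map[string]string{
		"claude-sonnet-": "my-cloud-model", // primary
		"claude-haiku-":  "my-fast-model",  // fast
	}
	target, ok := resolveAlias("claude-sonnet-4-20250514", aliases)
	fmt.Println(target, ok) // my-cloud-model true
}
```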
@@ -3,6 +3,7 @@
 package config
 
 import (
+    "context"
     "encoding/json"
     "errors"
     "fmt"
@@ -10,15 +11,18 @@ import (
     "os"
     "path/filepath"
     "strings"
+
+    "github.com/ollama/ollama/api"
 )
 
 type integration struct {
-    Models  []string          `json:"models"`
-    Aliases map[string]string `json:"aliases,omitempty"`
+    Models []string `json:"models"`
 }
 
 type config struct {
-    Integrations map[string]*integration `json:"integrations"`
+    Integrations  map[string]*integration `json:"integrations"`
+    LastModel     string                  `json:"last_model,omitempty"`
+    LastSelection string                  `json:"last_selection,omitempty"` // "run" or integration name
 }
 
 func configPath() (string, error) {
@@ -134,21 +138,81 @@ func saveIntegration(appName string, models []string) error {
         return err
     }
 
-    key := strings.ToLower(appName)
-    existing := cfg.Integrations[key]
-    var aliases map[string]string
-    if existing != nil && existing.Aliases != nil {
-        aliases = existing.Aliases
-    }
-
-    cfg.Integrations[key] = &integration{
-        Models:  models,
-        Aliases: aliases,
+    cfg.Integrations[strings.ToLower(appName)] = &integration{
+        Models: models,
     }
 
     return save(cfg)
 }
 
+// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
+func IntegrationModel(appName string) string {
+    ic, err := loadIntegration(appName)
+    if err != nil || len(ic.Models) == 0 {
+        return ""
+    }
+    return ic.Models[0]
+}
+
+// LastModel returns the last model that was run, or empty string if none.
+func LastModel() string {
+    cfg, err := load()
+    if err != nil {
+        return ""
+    }
+    return cfg.LastModel
+}
+
+// SetLastModel saves the last model that was run.
+func SetLastModel(model string) error {
+    cfg, err := load()
+    if err != nil {
+        return err
+    }
+    cfg.LastModel = model
+    return save(cfg)
+}
+
+// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
+func LastSelection() string {
+    cfg, err := load()
+    if err != nil {
+        return ""
+    }
+    return cfg.LastSelection
+}
+
+// SetLastSelection saves the last menu selection ("run" or integration name).
+func SetLastSelection(selection string) error {
+    cfg, err := load()
+    if err != nil {
+        return err
+    }
+    cfg.LastSelection = selection
+    return save(cfg)
+}
+
+// ModelExists checks if a model exists on the Ollama server.
+func ModelExists(ctx context.Context, name string) bool {
+    if name == "" {
+        return false
+    }
+    client, err := api.ClientFromEnvironment()
+    if err != nil {
+        return false
+    }
+    models, err := client.List(ctx)
+    if err != nil {
+        return false
+    }
+    for _, m := range models.Models {
+        if m.Name == name || strings.HasPrefix(m.Name, name+":") {
+            return true
+        }
+    }
+    return false
+}
+
 func loadIntegration(appName string) (*integration, error) {
     cfg, err := load()
     if err != nil {
@@ -163,29 +227,6 @@ func loadIntegration(appName string) (*integration, error) {
     return ic, nil
 }
 
-func saveAliases(appName string, aliases map[string]string) error {
-    if appName == "" {
-        return errors.New("app name cannot be empty")
-    }
-
-    cfg, err := load()
-    if err != nil {
-        return err
-    }
-
-    key := strings.ToLower(appName)
-    existing := cfg.Integrations[key]
-    if existing == nil {
-        existing = &integration{}
-    }
-
-    // Replace aliases entirely (not merge) so deletions are persisted
-    existing.Aliases = aliases
-
-    cfg.Integrations[key] = existing
-    return save(cfg)
-}
-
 func listIntegrations() ([]integration, error) {
     cfg, err := load()
     if err != nil {
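After this change the on-disk config gains last_model and last_selection and loses per-integration aliases. A sketch of the resulting ~/.ollama/config.json shape, printed from Go; the model names are placeholders:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Shape implied by the struct tags after this change; values are placeholders.
func main() {
	cfg := map[string]any{
		"integrations": map[string]any{
			"claude": map[string]any{"models": []string{"qwen3:8b"}},
		},
		"last_model":     "qwen3:8b",
		"last_selection": "run",
	}
	out, _ := json.MarshalIndent(cfg, "", "  ")
	fmt.Println(string(out)) // roughly what ~/.ollama/config.json would contain
}
```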
@@ -1,677 +0,0 @@
-package config
-
-import (
-    "context"
-    "errors"
-    "os"
-    "path/filepath"
-    "testing"
-)
-
-func TestSetAliases_CloudModel(t *testing.T) {
-    // Test the SetAliases logic by checking the alias map behavior
-    aliases := map[string]string{
-        "primary": "kimi-k2.5:cloud",
-        "fast":    "kimi-k2.5:cloud",
-    }
-
-    // Verify fast is set (cloud model behavior)
-    if aliases["fast"] == "" {
-        t.Error("cloud model should have fast alias set")
-    }
-    if aliases["fast"] != aliases["primary"] {
-        t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
-    }
-}
-
-func TestSetAliases_LocalModel(t *testing.T) {
-    aliases := map[string]string{
-        "primary": "llama3.2:latest",
-    }
-    // Simulate local model behavior: fast should be empty
-    delete(aliases, "fast")
-
-    if aliases["fast"] != "" {
-        t.Error("local model should have empty fast alias")
-    }
-}
-
-func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // First save with both primary and fast
-    initial := map[string]string{
-        "primary": "cloud-model",
-        "fast":    "cloud-model",
-    }
-    if err := saveAliases("claude", initial); err != nil {
-        t.Fatalf("failed to save initial aliases: %v", err)
-    }
-
-    // Verify both are saved
-    loaded, err := loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if loaded.Aliases["fast"] != "cloud-model" {
-        t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
-    }
-
-    // Now save without fast (simulating switch to local model)
-    updated := map[string]string{
-        "primary": "local-model",
-        // fast intentionally missing
-    }
-    if err := saveAliases("claude", updated); err != nil {
-        t.Fatalf("failed to save updated aliases: %v", err)
-    }
-
-    // Verify fast is GONE (not merged/preserved)
-    loaded, err = loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load after update: %v", err)
-    }
-    if loaded.Aliases["fast"] != "" {
-        t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
-    }
-    if loaded.Aliases["primary"] != "local-model" {
-        t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
-    }
-}
-
-func TestSaveAliases_PreservesModels(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // First save integration with models
-    if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
-        t.Fatalf("failed to save integration: %v", err)
-    }
-
-    // Then update aliases
-    aliases := map[string]string{"primary": "new-model"}
-    if err := saveAliases("claude", aliases); err != nil {
-        t.Fatalf("failed to save aliases: %v", err)
-    }
-
-    // Verify models are preserved
-    loaded, err := loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
-        t.Errorf("models should be preserved, got %v", loaded.Models)
-    }
-}
-
-// TestSaveAliases_EmptyMap clears all aliases
-func TestSaveAliases_EmptyMap(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // Save with aliases
-    if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
-        t.Fatalf("failed to save: %v", err)
-    }
-
-    // Save empty map
-    if err := saveAliases("claude", map[string]string{}); err != nil {
-        t.Fatalf("failed to save empty: %v", err)
-    }
-
-    loaded, err := loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if len(loaded.Aliases) != 0 {
-        t.Errorf("aliases should be empty, got %v", loaded.Aliases)
-    }
-}
-
-// TestSaveAliases_NilMap handles nil gracefully
-func TestSaveAliases_NilMap(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // Save with aliases first
-    if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
-        t.Fatalf("failed to save: %v", err)
-    }
-
-    // Save nil map - should clear aliases
-    if err := saveAliases("claude", nil); err != nil {
-        t.Fatalf("failed to save nil: %v", err)
-    }
-
-    loaded, err := loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if len(loaded.Aliases) > 0 {
-        t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
-    }
-}
-
-// TestSaveAliases_EmptyAppName returns error
-func TestSaveAliases_EmptyAppName(t *testing.T) {
-    err := saveAliases("", map[string]string{"primary": "model"})
-    if err == nil {
-        t.Error("expected error for empty app name")
-    }
-}
-
-func TestSaveAliases_CaseInsensitive(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
-        t.Fatalf("failed to save: %v", err)
-    }
-
-    // Load with different case
-    loaded, err := loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if loaded.Aliases["primary"] != "model1" {
-        t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
-    }
-
-    // Update with different case
-    if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
-        t.Fatalf("failed to update: %v", err)
-    }
-
-    loaded, err = loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load after update: %v", err)
-    }
-    if loaded.Aliases["primary"] != "model2" {
-        t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
-    }
-}
-
-// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
-func TestSaveAliases_CreatesIntegration(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // Save aliases for non-existent integration
-    if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
-        t.Fatalf("failed to save: %v", err)
-    }
-
-    loaded, err := loadIntegration("newintegration")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if loaded.Aliases["primary"] != "model" {
-        t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
-    }
-}
-
-func TestConfigureAliases_AliasMap(t *testing.T) {
-    t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
-        aliases := make(map[string]string)
-        aliases["primary"] = "cloud-model"
-
-        // Simulate cloud model behavior
-        isCloud := true
-        if isCloud {
-            if aliases["fast"] == "" {
-                aliases["fast"] = aliases["primary"]
-            }
-        }
-
-        if aliases["fast"] != "cloud-model" {
-            t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
-        }
-    })
-
-    t.Run("cloud model preserves custom fast", func(t *testing.T) {
-        aliases := map[string]string{
-            "primary": "cloud-model",
-            "fast":    "custom-fast-model",
-        }
-
-        // Simulate cloud model behavior - should preserve existing fast
-        isCloud := true
-        if isCloud {
-            if aliases["fast"] == "" {
-                aliases["fast"] = aliases["primary"]
-            }
-        }
-
-        if aliases["fast"] != "custom-fast-model" {
-            t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
-        }
-    })
-
-    t.Run("local model clears fast", func(t *testing.T) {
-        aliases := map[string]string{
-            "primary": "local-model",
-            "fast":    "should-be-cleared",
-        }
-
-        // Simulate local model behavior
-        isCloud := false
-        if !isCloud {
-            delete(aliases, "fast")
-        }
-
-        if aliases["fast"] != "" {
-            t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
-        }
-    })
-
-    t.Run("switching cloud to local clears fast", func(t *testing.T) {
-        // Start with cloud config
-        aliases := map[string]string{
-            "primary": "cloud-model",
-            "fast":    "cloud-model",
-        }
-
-        // Switch to local
-        aliases["primary"] = "local-model"
-        isCloud := false
-        if !isCloud {
-            delete(aliases, "fast")
-        }
-
-        if aliases["fast"] != "" {
-            t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
-        }
-        if aliases["primary"] != "local-model" {
-            t.Errorf("primary should be updated, got %q", aliases["primary"])
-        }
-    })
-
-    t.Run("switching local to cloud sets fast", func(t *testing.T) {
-        // Start with local config (no fast)
-        aliases := map[string]string{
-            "primary": "local-model",
-        }
-
-        // Switch to cloud
-        aliases["primary"] = "cloud-model"
-        isCloud := true
-        if isCloud {
-            if aliases["fast"] == "" {
-                aliases["fast"] = aliases["primary"]
-            }
-        }
-
-        if aliases["fast"] != "cloud-model" {
-            t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
-        }
-    })
-}
-
-func TestSetAliases_PrefixMapping(t *testing.T) {
-    // This tests the expected mapping without needing a real client
-    aliases := map[string]string{
-        "primary": "my-cloud-model",
-        "fast":    "my-fast-model",
-    }
-
-    expectedMappings := map[string]string{
-        "claude-sonnet-": aliases["primary"],
-        "claude-haiku-":  aliases["fast"],
-    }
-
-    if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
-        t.Errorf("claude-sonnet- should map to primary")
-    }
-    if expectedMappings["claude-haiku-"] != "my-fast-model" {
-        t.Errorf("claude-haiku- should map to fast")
-    }
-}
-
-func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
-    aliases := map[string]string{
-        "primary": "local-model",
-        // fast is empty/missing - indicates local model
-    }
-
-    prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
-
-    // Verify the logic: when fast is empty, we should delete
-    if aliases["fast"] != "" {
-        t.Error("fast should be empty for local model")
-    }
-
-    // Verify we have the right prefixes to delete
-    if len(prefixesToDelete) != 2 {
-        t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
-    }
-}
-
-// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
-func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // Simulate: server fails, config should NOT be saved
-    serverErr := errors.New("server unavailable")
-
-    if serverErr == nil {
-        t.Error("config should NOT be saved when server fails")
-    }
-}
-
-// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
-func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // Simulate: server succeeds, config should be saved
-    var serverErr error
-    if serverErr != nil {
-        t.Fatal("server should succeed")
-    }
-
-    if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
-        t.Fatalf("saveAliases failed: %v", err)
-    }
-
-    // Verify it was actually saved
-    loaded, err := loadIntegration("claude")
-    if err != nil {
-        t.Fatalf("failed to load: %v", err)
-    }
-    if loaded.Aliases["primary"] != "model" {
-        t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
-    }
-}
-
-func TestConfigFile_PreservesUnknownFields(t *testing.T) {
-    tmpDir := t.TempDir()
-    setTestHome(t, tmpDir)
-
-    // Write config with extra fields
-    configPath := filepath.Join(tmpDir, ".ollama", "config.json")
-    os.MkdirAll(filepath.Dir(configPath), 0o755)
-
-    // Note: Our config struct only has Integrations, so top-level unknown fields
-    // won't be preserved by our current implementation. This test documents that.
-    initialConfig := `{
-        "integrations": {
-            "claude": {
-                "models": ["model1"],
-                "aliases": {"primary": "model1"},
-                "unknownField": "should be lost"
-            }
-        },
-        "topLevelUnknown": "will be lost"
-    }`
-    os.WriteFile(configPath, []byte(initialConfig), 0o644)
-
-    // Update aliases
-    if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
-        t.Fatalf("failed to save: %v", err)
-    }
-
-    // Read raw file to check
-    data, _ := os.ReadFile(configPath)
-    content := string(data)
-
-    // models should be preserved
-    if !contains(content, "model1") {
-        t.Error("models should be preserved")
-    }
-
-    // primary should be updated
-    if !contains(content, "model2") {
-        t.Error("primary should be updated to model2")
-    }
-}
-
-func contains(s, substr string) bool {
-    return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
-}
-
-func containsHelper(s, substr string) bool {
-    for i := 0; i <= len(s)-len(substr); i++ {
-        if s[i:i+len(substr)] == substr {
-            return true
-        }
-    }
-    return false
-}
-
-func TestClaudeImplementsAliasConfigurer(t *testing.T) {
-    c := &Claude{}
-    var _ AliasConfigurer = c // Compile-time check
-}
-
-func TestModelNameEdgeCases(t *testing.T) {
-    testCases := []struct {
-        name  string
-        model string
-    }{
-        {"simple", "llama3.2"},
-        {"with tag", "llama3.2:latest"},
-        {"with cloud tag", "kimi-k2.5:cloud"},
-        {"with namespace", "library/llama3.2"},
-        {"with dots", "glm-4.7-flash"},
-        {"with numbers", "qwen3:8b"},
-    }
-
-    for _, tc := range testCases {
-        t.Run(tc.name, func(t *testing.T) {
-            tmpDir := t.TempDir()
-            setTestHome(t, tmpDir)
-
-            aliases := map[string]string{"primary": tc.model}
-            if err := saveAliases("claude", aliases); err != nil {
-                t.Fatalf("failed to save model %q: %v", tc.model, err)
-            }
-
-            loaded, err := loadIntegration("claude")
-            if err != nil {
-                t.Fatalf("failed to load: %v", err)
-            }
-            if loaded.Aliases["primary"] != tc.model {
-                t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
-            }
-        })
-    }
-}
-
-func TestSwitchingScenarios(t *testing.T) {
-    t.Run("cloud to local removes fast", func(t *testing.T) {
-        tmpDir := t.TempDir()
-        setTestHome(t, tmpDir)
-
-        // Initial cloud config
-        if err := saveAliases("claude", map[string]string{
-            "primary": "cloud-model",
-            "fast":    "cloud-model",
-        }); err != nil {
-            t.Fatal(err)
-        }
-
-        // Switch to local (no fast)
-        if err := saveAliases("claude", map[string]string{
-            "primary": "local-model",
-        }); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ := loadIntegration("claude")
-        if loaded.Aliases["fast"] != "" {
-            t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
-        }
-        if loaded.Aliases["primary"] != "local-model" {
-            t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
-        }
-    })
-
-    t.Run("local to cloud adds fast", func(t *testing.T) {
-        tmpDir := t.TempDir()
-        setTestHome(t, tmpDir)
-
-        // Initial local config
-        if err := saveAliases("claude", map[string]string{
-            "primary": "local-model",
-        }); err != nil {
-            t.Fatal(err)
-        }
-
-        // Switch to cloud (with fast)
-        if err := saveAliases("claude", map[string]string{
-            "primary": "cloud-model",
-            "fast":    "cloud-model",
-        }); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ := loadIntegration("claude")
-        if loaded.Aliases["fast"] != "cloud-model" {
-            t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
-        }
-    })
-
-    t.Run("cloud to different cloud updates both", func(t *testing.T) {
-        tmpDir := t.TempDir()
-        setTestHome(t, tmpDir)
-
-        // Initial cloud config
-        if err := saveAliases("claude", map[string]string{
-            "primary": "cloud-model-1",
-            "fast":    "cloud-model-1",
-        }); err != nil {
-            t.Fatal(err)
-        }
-
-        // Switch to different cloud
-        if err := saveAliases("claude", map[string]string{
-            "primary": "cloud-model-2",
-            "fast":    "cloud-model-2",
-        }); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ := loadIntegration("claude")
-        if loaded.Aliases["primary"] != "cloud-model-2" {
-            t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
-        }
-        if loaded.Aliases["fast"] != "cloud-model-2" {
-            t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
-        }
-    })
-}
-
-func TestToolCapabilityFiltering(t *testing.T) {
-    t.Run("all models checked for tool capability", func(t *testing.T) {
-        // Both cloud and local models are checked for tool capability via Show API
-        // Only models with "tools" in capabilities are included
-        m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
-        if !m.ToolCapable {
-            t.Error("tool capable model should be marked as such")
-        }
-    })
-
-    t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
-        m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
-        if !m.ToolCapable {
-            t.Error("ToolCapable field should be accessible")
-        }
-    })
-}
-
-func TestIsCloudModel_RequiresClient(t *testing.T) {
-    t.Run("nil client always returns false", func(t *testing.T) {
-        // isCloudModel now only uses Show API, no suffix detection
-        if isCloudModel(context.Background(), nil, "model:cloud") {
-            t.Error("nil client should return false regardless of suffix")
-        }
-        if isCloudModel(context.Background(), nil, "local-model") {
-            t.Error("nil client should return false")
-        }
-    })
-}
-
-func TestModelsAndAliasesMustStayInSync(t *testing.T) {
-    t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
-        tmpDir := t.TempDir()
-        setTestHome(t, tmpDir)
-
-        // Save aliases with one model
-        if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
-            t.Fatal(err)
-        }
-
-        // Save integration with same model (this is the pattern we use)
-        if err := saveIntegration("claude", []string{"model-a"}); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ := loadIntegration("claude")
-        if loaded.Aliases["primary"] != loaded.Models[0] {
-            t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
-        }
-    })
-
-    t.Run("out of sync config is detectable", func(t *testing.T) {
-        tmpDir := t.TempDir()
-        setTestHome(t, tmpDir)
-
-        // Simulate out-of-sync state (like manual edit or bug)
-        if err := saveIntegration("claude", []string{"old-model"}); err != nil {
-            t.Fatal(err)
-        }
-        if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ := loadIntegration("claude")
-
-        // They should be different (this is the bug state)
-        if loaded.Models[0] == loaded.Aliases["primary"] {
-            t.Error("expected out-of-sync state for this test")
-        }
-
-        // The fix: when updating aliases, also update models
-        if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ = loadIntegration("claude")
-        if loaded.Models[0] != loaded.Aliases["primary"] {
-            t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
-                loaded.Models[0], loaded.Aliases["primary"])
-        }
-    })
-
-    t.Run("updating primary alias updates models too", func(t *testing.T) {
-        tmpDir := t.TempDir()
-        setTestHome(t, tmpDir)
-
-        // Initial state
-        if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
-            t.Fatal(err)
-        }
-        if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
-            t.Fatal(err)
-        }
-
-        // Update aliases AND models together
-        newAliases := map[string]string{"primary": "updated-model"}
-        if err := saveAliases("claude", newAliases); err != nil {
-            t.Fatal(err)
-        }
-        if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
-            t.Fatal(err)
-        }
-
-        loaded, _ := loadIntegration("claude")
-        if loaded.Models[0] != "updated-model" {
-            t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
-        }
-        if loaded.Aliases["primary"] != "updated-model" {
-            t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
-        }
-    })
-}
@@ -46,53 +46,6 @@ func TestIntegrationConfig(t *testing.T) {
         }
     })
 
-    t.Run("save and load aliases", func(t *testing.T) {
-        models := []string{"llama3.2"}
-        if err := saveIntegration("claude", models); err != nil {
-            t.Fatal(err)
-        }
-        aliases := map[string]string{
-            "primary": "llama3.2:70b",
-            "fast":    "llama3.2:8b",
-        }
-        if err := saveAliases("claude", aliases); err != nil {
-            t.Fatal(err)
-        }
-
-        config, err := loadIntegration("claude")
-        if err != nil {
-            t.Fatal(err)
-        }
-        if config.Aliases == nil {
-            t.Fatal("expected aliases to be saved")
-        }
-        for k, v := range aliases {
-            if config.Aliases[k] != v {
-                t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
-            }
-        }
-    })
-
-    t.Run("saveIntegration preserves aliases", func(t *testing.T) {
-        if err := saveIntegration("claude", []string{"model-a"}); err != nil {
-            t.Fatal(err)
-        }
-        if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
-            t.Fatal(err)
-        }
-
-        if err := saveIntegration("claude", []string{"model-b"}); err != nil {
-            t.Fatal(err)
-        }
-        config, err := loadIntegration("claude")
-        if err != nil {
-            t.Fatal(err)
-        }
-        if config.Aliases["primary"] != "model-a" {
-            t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
-        }
-    })
-
     t.Run("defaultModel returns first model", func(t *testing.T) {
         saveIntegration("codex", []string{"model-a", "model-b"})
@@ -4,9 +4,12 @@ import (
	"context"
	"errors"
	"fmt"
	"io"
	"maps"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"slices"
	"strings"

@@ -39,15 +42,6 @@ type Editor interface {
	Models() []string
}

// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
	// ConfigureAliases prompts the user to configure aliases and returns the updated map.
	ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
	// SetAliases syncs the configured aliases to the server
	SetAliases(ctx context.Context, aliases map[string]string) error
}

// integrations is the registry of available integrations.
var integrations = map[string]Runner{
	"claude": &Claude{},

@@ -61,7 +55,7 @@ var integrations = map[string]Runner{

// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
var recommendedModels = []ModelItem{
	{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
	{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
	{Name: "glm-4.7:cloud", Description: "Recommended"},

@@ -74,6 +68,256 @@ var integrationAliases = map[string]bool{
	"moltbot": true,
}

// integrationInstallURLs maps integration names to their install script URLs.
var integrationInstallURLs = map[string]string{
	"claude":   "https://claude.ai/install.sh",
	"openclaw": "https://openclaw.ai/install.sh",
	"droid":    "https://app.factory.ai/cli",
	"opencode": "https://opencode.ai/install",
}

// CanInstallIntegration returns true if we have an install script for this integration.
func CanInstallIntegration(name string) bool {
	_, ok := integrationInstallURLs[name]
	return ok
}

// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
	switch name {
	case "claude":
		c := &Claude{}
		_, err := c.findPath()
		return err == nil
	case "openclaw":
		if _, err := exec.LookPath("openclaw"); err == nil {
			return true
		}
		if _, err := exec.LookPath("clawdbot"); err == nil {
			return true
		}
		return false
	case "codex":
		_, err := exec.LookPath("codex")
		return err == nil
	case "droid":
		_, err := exec.LookPath("droid")
		return err == nil
	case "opencode":
		_, err := exec.LookPath("opencode")
		return err == nil
	default:
		return true // Assume installed for unknown integrations
	}
}

// InstallIntegration downloads and runs the install script for an integration.
func InstallIntegration(name string) error {
	url, ok := integrationInstallURLs[name]
	if !ok {
		return fmt.Errorf("no install script available for %s", name)
	}

	// Download the install script
	resp, err := http.Get(url)
	if err != nil {
		return fmt.Errorf("failed to download install script: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to download install script: HTTP %d", resp.StatusCode)
	}

	script, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read install script: %w", err)
	}

	// Create a temporary file for the script
	tmpDir := os.TempDir()
	scriptPath := filepath.Join(tmpDir, fmt.Sprintf("install-%s.sh", name))
	if err := os.WriteFile(scriptPath, script, 0o700); err != nil {
		return fmt.Errorf("failed to write install script: %w", err)
	}
	defer os.Remove(scriptPath)

	// Execute the script with bash
	cmd := exec.Command("bash", scriptPath)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		return fmt.Errorf("install script failed: %w", err)
	}

	return nil
}
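Taken together, these three helpers give callers a safe install flow: check before downloading, and skip integrations with no script. A minimal sketch of the intended call order, assuming only the functions introduced above (the wrapper itself is illustrative):

// ensureInstalled is a hypothetical caller showing how the helpers compose.
func ensureInstalled(name string) error {
	if IsIntegrationInstalled(name) {
		return nil
	}
	if !CanInstallIntegration(name) {
		return fmt.Errorf("%s is not installed and no install script is available", name)
	}
	return InstallIntegration(name)
}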
// SelectModel lets the user select a model to run.
// ModelItem represents a model for selection.
type ModelItem struct {
	Name        string
	Description string
}

// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)

// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)

// SelectModelWithSelector prompts the user to select a model using the provided selector.
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return "", err
	}

	models, err := client.List(ctx)
	if err != nil {
		return "", err
	}

	var existing []modelInfo
	for _, m := range models.Models {
		existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}

	lastModel := LastModel()
	var preChecked []string
	if lastModel != "" {
		preChecked = []string{lastModel}
	}

	items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)

	if len(items) == 0 {
		return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
	}

	// Sort with last model first, then existing models, then recommendations
	slices.SortStableFunc(items, func(a, b ModelItem) int {
		aIsLast := a.Name == lastModel
		bIsLast := b.Name == lastModel
		if aIsLast != bIsLast {
			if aIsLast {
				return -1
			}
			return 1
		}
		aExists := existingModels[a.Name]
		bExists := existingModels[b.Name]
		if aExists != bExists {
			if aExists {
				return -1
			}
			return 1
		}
		return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
	})

	selected, err := selector("Select model to run:", items)
	if err != nil {
		return "", err
	}

	// If the selected model isn't installed, pull it first
	if !existingModels[selected] {
		msg := fmt.Sprintf("Download %s?", selected)
		if ok, err := confirmPrompt(msg); err != nil {
			return "", err
		} else if !ok {
			return "", errCancelled
		}
		fmt.Fprintf(os.Stderr, "\n")
		if err := pullModel(ctx, client, selected); err != nil {
			return "", fmt.Errorf("failed to pull %s: %w", selected, err)
		}
	}

	// If it's a cloud model, ensure user is signed in
	if cloudModels[selected] {
		user, err := client.Whoami(ctx)
		if err == nil && user != nil && user.Name != "" {
			return selected, nil
		}

		var aErr api.AuthorizationError
		if !errors.As(err, &aErr) || aErr.SigninURL == "" {
			return "", err
		}

		yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
		if err != nil || !yes {
			return "", fmt.Errorf("%s requires sign in", selected)
		}

		fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)

		// Auto-open browser (best effort, fail silently)
		switch runtime.GOOS {
		case "darwin":
			_ = exec.Command("open", aErr.SigninURL).Start()
		case "linux":
			_ = exec.Command("xdg-open", aErr.SigninURL).Start()
		case "windows":
			_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
		}

		spinnerFrames := []string{"|", "/", "-", "\\"}
		frame := 0

		fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])

		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				fmt.Fprintf(os.Stderr, "\r\033[K")
				return "", ctx.Err()
			case <-ticker.C:
				frame++
				fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

				// poll every 10th frame (~2 seconds)
				if frame%10 == 0 {
					u, err := client.Whoami(ctx)
					if err == nil && u != nil && u.Name != "" {
						fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
						return selected, nil
					}
				}
			}
		}
	}

	return selected, nil
}

func SelectModel(ctx context.Context) (string, error) {
	return SelectModelWithSelector(ctx, defaultSingleSelector)
}

func defaultSingleSelector(title string, items []ModelItem) (string, error) {
	selectItems := make([]selectItem, len(items))
	for i, item := range items {
		selectItems[i] = selectItem(item)
	}
	return selectPrompt(title, selectItems)
}

func defaultMultiSelector(title string, items []ModelItem, preChecked []string) ([]string, error) {
	selectItems := make([]selectItem, len(items))
	for i, item := range items {
		selectItems[i] = selectItem(item)
	}
	return multiSelectPrompt(title, selectItems, preChecked)
}
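SingleSelector and MultiSelector make the prompt implementation pluggable, which is what lets the TUI reuse this flow. A minimal sketch of a non-interactive selector, assuming only the types above (pickFirst is illustrative, not part of the change):

// pickFirst is a hypothetical SingleSelector that skips the prompt and
// takes the first (highest-ranked) item; handy in scripts or tests.
var pickFirst SingleSelector = func(title string, items []ModelItem) (string, error) {
	if len(items) == 0 {
		return "", fmt.Errorf("no items to select from")
	}
	return items[0].Name, nil
}

// Usage: model, err := SelectModelWithSelector(ctx, pickFirst)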
func selectIntegration() (string, error) {
	if len(integrations) == 0 {
		return "", fmt.Errorf("no integrations available")

@@ -96,8 +340,8 @@ func selectIntegration() (string, error) {
	return selectPrompt("Select integration:", items)
}

// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
	r, ok := integrations[name]
	if !ok {
		return nil, fmt.Errorf("unknown integration: %s", name)

@@ -133,16 +377,12 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {

	var selected []string
	if _, ok := r.(Editor); ok {
		selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
		selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
		if err != nil {
			return nil, err
		}
	} else {
		prompt := fmt.Sprintf("Select model for %s:", r)
		if _, ok := r.(AliasConfigurer); ok {
			prompt = fmt.Sprintf("Select Primary model for %s:", r)
		}
		model, err := selectPrompt(prompt, items)
		model, err := single(fmt.Sprintf("Select model for %s:", r), items)
		if err != nil {
			return nil, err
		}

@@ -170,123 +410,78 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
		}
	}

	if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
		return nil, err
	}

	return selected, nil
}

func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
	if existingModels[model] {
		return nil
	}
	msg := fmt.Sprintf("Download %s?", model)
	if ok, err := confirmPrompt(msg); err != nil {
		return err
	} else if !ok {
		return errCancelled
	}
	fmt.Fprintf(os.Stderr, "\n")
	if err := pullModel(ctx, client, model); err != nil {
		return fmt.Errorf("failed to pull %s: %w", model, err)
	}
	return nil
}

func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, nil, nil, nil, err
	}

	models, err := client.List(ctx)
	if err != nil {
		return nil, nil, nil, nil, err
	}

	var existing []modelInfo
	for _, m := range models.Models {
		existing = append(existing, modelInfo{
			Name:   m.Name,
			Remote: m.RemoteModel != "",
		})
	}

	items, _, existingModels, cloudModels := buildModelList(existing, nil, "")

	if len(items) == 0 {
		return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
	}

	return items, existingModels, cloudModels, client, nil
}

func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		return nil
	}
	if len(selectedCloudModels) > 0 {
		// ensure user is signed in
		user, err := client.Whoami(ctx)
		if err == nil && user != nil && user.Name != "" {
			return selected, nil
		}

	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		return nil
	}
		var aErr api.AuthorizationError
		if !errors.As(err, &aErr) || aErr.SigninURL == "" {
			return nil, err
		}

	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		return err
	}
		modelList := strings.Join(selectedCloudModels, ", ")
		yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
		if err != nil || !yes {
			return nil, fmt.Errorf("%s requires sign in", modelList)
		}

	modelList := strings.Join(selectedCloudModels, ", ")
	yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if err != nil || !yes {
		return fmt.Errorf("%s requires sign in", modelList)
	}
		fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)

	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
	// TODO(parthsareen): extract into auth package for cmd
	// Auto-open browser (best effort, fail silently)
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", aErr.SigninURL).Start()
	case "linux":
		_ = exec.Command("xdg-open", aErr.SigninURL).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
	}

	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", aErr.SigninURL).Start()
	case "linux":
		_ = exec.Command("xdg-open", aErr.SigninURL).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
	}
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0

	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])

	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return nil, ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

	for {
		select {
		case <-ctx.Done():
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

			// poll every 10th frame (~2 seconds)
			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
			// poll every 10th frame (~2 seconds)
			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return selected, nil
				}
			}
		}
	}
}

	return selected, nil
}

// selectModels lets the user select models for an integration using default selectors.
func selectModels(ctx context.Context, name, current string) ([]string, error) {
	return selectModelsWithSelectors(ctx, name, current, defaultSingleSelector, defaultMultiSelector)
}
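The hunk above interleaves the removed inline sign-in code with the new helpers, so the resulting control flow is easy to lose. Reconstructed as a sketch using only pullIfNeeded, listModels, and ensureAuth from this diff (the wrapper function is illustrative, not part of the change):

// selectAndAuthorize shows how the extracted helpers compose:
// list models, pull anything missing, then make sure any selected
// cloud models have a signed-in user.
func selectAndAuthorize(ctx context.Context, selected []string) error {
	_, existingModels, cloudModels, client, err := listModels(ctx)
	if err != nil {
		return err
	}
	for _, m := range selected {
		if err := pullIfNeeded(ctx, client, existingModels, m); err != nil {
			return err
		}
	}
	return ensureAuth(ctx, client, cloudModels, selected)
}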
func runIntegration(name, modelName string, args []string) error {

@@ -294,42 +489,114 @@ func runIntegration(name, modelName string, args []string) error {
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}

	fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
	return r.Run(modelName, args)
}

// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
	aliases := make(map[string]string)
	for k, v := range existing {
		aliases[k] = v
// LaunchIntegration launches the named integration using saved config or prompts for setup.
func LaunchIntegration(name string) error {
	r, ok := integrations[name]
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}
	aliases["primary"] = model

	if isCloudModel(ctx, client, model) {
		if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
			aliases["fast"] = model
	// Try to use saved config
	if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
		return runIntegration(name, config.Models[0], nil)
	}

	// No saved config - prompt user to run setup
	return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
}

// LaunchIntegrationWithModel launches the named integration with the specified model.
func LaunchIntegrationWithModel(name, modelName string) error {
	return runIntegration(name, modelName, nil)
}

// SaveIntegrationModel saves the model for an integration.
func SaveIntegrationModel(name, modelName string) error {
	// Load existing models and prepend the new one
	var models []string
	if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
		models = existing.Models
		// Remove the model if it already exists
		for i, m := range models {
			if m == modelName {
				models = append(models[:i], models[i+1:]...)
				break
			}
		}
	} else {
		delete(aliases, "fast")
	}
	// Prepend the new model
	models = append([]string{modelName}, models...)
	return saveIntegration(name, models)
}

// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
	r, ok := integrations[name]
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}

	if err := ac.SetAliases(ctx, aliases); err != nil {
	models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
	if errors.Is(err, errCancelled) {
		return nil
	}
	if err != nil {
		return err
	}
	return saveAliases(name, aliases)

	if editor, isEditor := r.(Editor); isEditor {
		paths := editor.Paths()
		if len(paths) > 0 {
			fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
			for _, p := range paths {
				fmt.Fprintf(os.Stderr, " %s\n", p)
			}
			fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())

			if ok, _ := confirmPrompt("Proceed?"); !ok {
				return nil
			}
		}

		if err := editor.Edit(models); err != nil {
			return fmt.Errorf("setup failed: %w", err)
		}
	}

	if err := saveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}

	if len(models) == 1 {
		fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
	} else {
		fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
	}

	return nil
}

// ConfigureIntegration allows the user to select/change the model for an integration.
func ConfigureIntegration(ctx context.Context, name string) error {
	return ConfigureIntegrationWithSelectors(ctx, name, defaultSingleSelector, defaultMultiSelector)
}

// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
// The runTUI callback is called when no arguments are provided (alias for main TUI).
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
	var modelFlag string
	var configFlag bool

	cmd := &cobra.Command{
		Use:   "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
		Short: "Launch an integration with Ollama",
		Long: `Launch an integration configured with Ollama models.
		Short: "Launch the Ollama menu or an integration",
		Long: `Launch the Ollama interactive menu, or directly launch a specific integration.

Without arguments, this is equivalent to running 'ollama' directly.

Supported integrations:
  claude  Claude Code

@@ -348,6 +615,12 @@ Examples:
		Args:    cobra.ArbitraryArgs,
		PreRunE: checkServerHeartbeat,
		RunE: func(cmd *cobra.Command, args []string) error {
			// No args - run the main TUI (same as 'ollama')
			if len(args) == 0 && modelFlag == "" && !configFlag {
				runTUI(cmd)
				return nil
			}

			// Extract integration name and args to pass through using -- separator
			var name string
			var passArgs []string

@@ -388,87 +661,9 @@ Examples:
				return fmt.Errorf("unknown integration: %s", name)
			}

			// Handle AliasConfigurer integrations (claude, codex)
			if ac, ok := r.(AliasConfigurer); ok {
				client, err := api.ClientFromEnvironment()
				if err != nil {
					return err
				}

				// Validate --model flag if provided
				if modelFlag != "" {
					if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
						return fmt.Errorf("model %q not found", modelFlag)
					}
				}

				var model string
				var existingAliases map[string]string

				// Load saved config
				if cfg, err := loadIntegration(name); err == nil {
					existingAliases = cfg.Aliases
					if len(cfg.Models) > 0 {
						model = cfg.Models[0]
						// AliasConfigurer integrations use single model; sanitize if multiple
						if len(cfg.Models) > 1 {
							_ = saveIntegration(name, []string{model})
						}
					}
				}

				// --model flag overrides saved model
				if modelFlag != "" {
					model = modelFlag
				}

				// Validate saved model still exists
				if model != "" && modelFlag == "" {
					if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
						fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
						model = ""
					}
				}

				// If no valid model or --config flag, show picker
				if model == "" || configFlag {
					aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
					if errors.Is(err, errCancelled) {
						return nil
					}
					if err != nil {
						return err
					}
					model = aliases["primary"]
					existingAliases = aliases
				}

				// Sync aliases and save
				if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
					fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
				}
				if err := saveIntegration(name, []string{model}); err != nil {
					return fmt.Errorf("failed to save: %w", err)
				}

				// Launch (unless --config without confirmation)
				if configFlag {
					if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
						return runIntegration(name, model, passArgs)
					}
					return nil
				}
				return runIntegration(name, model, passArgs)
			}

			// Validate --model flag for non-AliasConfigurer integrations
			if modelFlag != "" {
				client, err := api.ClientFromEnvironment()
				if err != nil {
					return err
				}
				if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
					return fmt.Errorf("model %q not found", modelFlag)
			if !configFlag && modelFlag == "" {
				if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
					return runIntegration(name, config.Models[0], passArgs)
				}
			}

@@ -482,8 +677,6 @@ Examples:
				}
			}
		}
			} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
				return runIntegration(name, saved.Models[0], passArgs)
			} else {
				var err error
				models, err = selectModels(cmd.Context(), name, "")

@@ -546,14 +739,13 @@ Examples:
}

type modelInfo struct {
	Name        string
	Remote      bool
	ToolCapable bool
	Name   string
	Remote bool
}

// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	existingModels = make(map[string]bool)
	cloudModels = make(map[string]bool)
	recommended := make(map[string]bool)

@@ -573,7 +765,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
		}
		displayName := strings.TrimSuffix(m.Name, ":latest")
		existingModels[displayName] = true
		item := selectItem{Name: displayName}
		item := ModelItem{Name: displayName}
		if recommended[displayName] {
			item.Description = "recommended"
		}

@@ -585,7 +777,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
			continue
		}
		items = append(items, rec)
		if strings.HasSuffix(rec.Name, ":cloud") {
		if isCloudModel(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}

@@ -622,7 +814,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
	}

	if hasLocalModel || hasCloudModel {
		slices.SortStableFunc(items, func(a, b selectItem) int {
		slices.SortStableFunc(items, func(a, b ModelItem) int {
			ac, bc := checked[a.Name], checked[b.Name]
			aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]

@@ -645,16 +837,58 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
	return items, preChecked, existingModels, cloudModels
}

// isCloudModel checks if a model is a cloud model using the Show API.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
	if client == nil {
		return false
	}
	resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
func isCloudModel(name string) bool {
	return strings.HasSuffix(name, ":cloud")
}

// GetModelItems returns a list of model items including recommendations for the TUI.
// It includes all locally available models plus recommended models that aren't installed.
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return false
		return nil, nil
	}
	return resp.RemoteModel != ""

	models, err := client.List(ctx)
	if err != nil {
		return nil, nil
	}

	var existing []modelInfo
	for _, m := range models.Models {
		existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}

	lastModel := LastModel()
	var preChecked []string
	if lastModel != "" {
		preChecked = []string{lastModel}
	}

	items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)

	// Sort with last model first, then existing models, then recommendations
	slices.SortStableFunc(items, func(a, b ModelItem) int {
		aIsLast := a.Name == lastModel
		bIsLast := b.Name == lastModel
		if aIsLast != bIsLast {
			if aIsLast {
				return -1
			}
			return 1
		}
		aExists := existingModels[a.Name]
		bExists := existingModels[b.Name]
		if aExists != bExists {
			if aExists {
				return -1
			}
			return 1
		}
		return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
	})

	return items, existingModels
}

func pullModel(ctx context.Context, client *api.Client, model string) error {
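One behavioral note on the hunk above: isCloudModel no longer consults the Show API, so it needs no client and cannot fail; a model is a cloud model exactly when its name carries the ":cloud" tag. A minimal standalone sketch of the new semantics, for illustration only:

package main

import (
	"fmt"
	"strings"
)

// Mirrors the suffix-only check introduced above: cloud models are
// identified purely by the ":cloud" tag, with no server round trip.
func isCloudModel(name string) bool {
	return strings.HasSuffix(name, ":cloud")
}

func main() {
	for _, m := range []string{"glm-4.7:cloud", "glm-4.7-flash", "model:cloudish"} {
		fmt.Printf("%-16s %v\n", m, isCloudModel(m)) // true, false, false
	}
}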
@@ -1,7 +1,6 @@
package config

import (
	"context"
	"fmt"
	"slices"
	"strings"

@@ -89,8 +88,10 @@ func TestLaunchCmd(t *testing.T) {
	mockCheck := func(cmd *cobra.Command, args []string) error {
		return nil
	}
	// Mock TUI function (not called in these tests)
	mockTUI := func(cmd *cobra.Command) {}

	cmd := LaunchCmd(mockCheck)
	cmd := LaunchCmd(mockCheck, mockTUI)

	t.Run("command structure", func(t *testing.T) {
		if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {

@@ -123,6 +124,75 @@ func TestLaunchCmd(t *testing.T) {
	})
}

func TestLaunchCmd_TUICallback(t *testing.T) {
	mockCheck := func(cmd *cobra.Command, args []string) error {
		return nil
	}

	t.Run("no args calls TUI", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}

		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{})
		_ = cmd.Execute()

		if !tuiCalled {
			t.Error("TUI callback should be called when no args provided")
		}
	})

	t.Run("integration arg bypasses TUI", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}

		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"claude"})
		// Will error because claude isn't configured, but that's OK
		_ = cmd.Execute()

		if tuiCalled {
			t.Error("TUI callback should NOT be called when integration arg provided")
		}
	})

	t.Run("--model flag bypasses TUI", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}

		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"--model", "test-model"})
		// Will error because no integration specified, but that's OK
		_ = cmd.Execute()

		if tuiCalled {
			t.Error("TUI callback should NOT be called when --model flag provided")
		}
	})

	t.Run("--config flag bypasses TUI", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}

		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"--config"})
		// Will error because no integration specified, but that's OK
		_ = cmd.Execute()

		if tuiCalled {
			t.Error("TUI callback should NOT be called when --config flag provided")
		}
	})
}

func TestRunIntegration_UnknownIntegration(t *testing.T) {
	err := runIntegration("unknown-integration", "model", nil)
	if err == nil {

@@ -163,7 +233,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {

func TestLaunchCmd_NilHeartbeat(t *testing.T) {
	// This should not panic - cmd creation should work even with nil
	cmd := LaunchCmd(nil)
	cmd := LaunchCmd(nil, nil)
	if cmd == nil {
		t.Fatal("LaunchCmd returned nil")
	}

@@ -298,18 +368,27 @@ func TestParseArgs(t *testing.T) {
}

func TestIsCloudModel(t *testing.T) {
	// isCloudModel now only uses Show API, so nil client always returns false
	t.Run("nil client returns false", func(t *testing.T) {
		models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
		for _, model := range models {
			if isCloudModel(context.Background(), nil, model) {
				t.Errorf("isCloudModel(%q) with nil client should return false", model)
	tests := []struct {
		name string
		want bool
	}{
		{"glm-4.7:cloud", true},
		{"kimi-k2.5:cloud", true},
		{"glm-4.7-flash", false},
		{"glm-4.7-flash:latest", false},
		{"cloud-model", false},
		{"model:cloudish", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := isCloudModel(tt.name); got != tt.want {
				t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
			}
		}
	})
		})
	}
}
}

func names(items []selectItem) []string {
func names(items []ModelItem) []string {
	var out []string
	for _, item := range items {
		out = append(out, item.Name)

@@ -501,41 +580,3 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
		t.Error("llama3.2 should not be in cloudModels")
	}
}

func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	// Save a config for opencode so it looks like a previous launch
	if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

	// Verify loadIntegration returns the saved models
	saved, err := loadIntegration("opencode")
	if err != nil {
		t.Fatal(err)
	}
	if len(saved.Models) == 0 {
		t.Fatal("expected saved models")
	}
	if saved.Models[0] != "llama3.2" {
		t.Errorf("expected llama3.2, got %s", saved.Models[0])
	}
}

func TestAliasConfigurerInterface(t *testing.T) {
	t.Run("claude implements AliasConfigurer", func(t *testing.T) {
		claude := &Claude{}
		if _, ok := interface{}(claude).(AliasConfigurer); !ok {
			t.Error("Claude should implement AliasConfigurer")
		}
	})

	t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
		codex := &Codex{}
		if _, ok := interface{}(codex).(AliasConfigurer); ok {
			t.Error("Codex should not implement AliasConfigurer")
		}
	})
}
@@ -17,6 +17,8 @@ type Openclaw struct{}

func (c *Openclaw) String() string { return "OpenClaw" }

const ansiGreen = "\033[32m"

func (c *Openclaw) Run(model string, args []string) error {
	bin := "openclaw"
	if _, err := exec.LookPath(bin); err != nil {
@@ -1,7 +1,6 @@
package config

import (
	"context"
	"encoding/json"
	"fmt"
	"maps"

@@ -11,52 +10,12 @@ import (
	"slices"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
)

// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}

// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	Context int
	Output  int
}

// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
}

// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
	if l, ok := cloudModelLimits[name]; ok {
		return l, true
	}
	base := strings.TrimSuffix(name, ":cloud")
	if base != name {
		if l, ok := cloudModelLimits[base]; ok {
			return l, true
		}
	}
	return cloudModelLimit{}, false
}

func (o *OpenCode) String() string { return "OpenCode" }

func (o *OpenCode) Run(model string, args []string) error {

@@ -154,8 +113,6 @@ func (o *OpenCode) Edit(modelList []string) error {
		}
	}

	client, _ := api.ClientFromEnvironment()

	for _, model := range modelList {
		if existing, ok := models[model].(map[string]any); ok {
			// migrate existing models without _launch marker

@@ -165,29 +122,12 @@ func (o *OpenCode) Edit(modelList []string) error {
				existing["name"] = strings.TrimSuffix(name, " [Ollama]")
			}
		}
			if isCloudModel(context.Background(), client, model) {
				if l, ok := lookupCloudModelLimit(model); ok {
					existing["limit"] = map[string]any{
						"context": l.Context,
						"output":  l.Output,
					}
				}
			}
			continue
		}
		entry := map[string]any{
		models[model] = map[string]any{
			"name":    model,
			"_launch": true,
		}
		if isCloudModel(context.Background(), client, model) {
			if l, ok := lookupCloudModelLimit(model); ok {
				entry["limit"] = map[string]any{
					"context": l.Context,
					"output":  l.Output,
				}
			}
		}
		models[model] = entry
	}

	ollama["models"] = models
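Worth making explicit from the hunks above: the rewritten Edit no longer computes cloud limits at all. A fresh entry is just the model name plus the _launch marker, and a limit the user already set should survive because existing entries are only migrated (name cleanup, then continue), never rebuilt. The shape of a freshly written entry, assuming the JSON layout used in the tests that follow (illustrative, not emitted verbatim by the code):

"provider": {
  "ollama": {
    "models": {
      "llama3.2": {"name": "llama3.2", "_launch": true}
    }
  }
}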
@@ -2,7 +2,6 @@ package config

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"testing"

@@ -496,165 +495,6 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
	}
}

func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
	t.Helper()
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatal(err)
	}
	var cfg map[string]any
	json.Unmarshal(data, &cfg)
	provider := cfg["provider"].(map[string]any)
	ollama := provider["ollama"].(map[string]any)
	models := ollama["models"].(map[string]any)
	entry, ok := models[model].(map[string]any)
	if !ok {
		t.Fatalf("model %s not found in config", model)
	}
	return entry
}

func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")

	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

	entry := readOpenCodeModel(t, configPath, "llama3.2")
	if entry["limit"] != nil {
		t.Errorf("local model should not have limit set, got %v", entry["limit"])
	}
}

func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")

	// Set up a model with a user-configured limit
	os.MkdirAll(configDir, 0o755)
	os.WriteFile(configPath, []byte(`{
		"provider": {
			"ollama": {
				"models": {
					"llama3.2": {
						"name": "llama3.2",
						"_launch": true,
						"limit": {"context": 8192, "output": 4096}
					}
				}
			}
		}
	}`), 0o644)

	// Re-edit should preserve the user's limit (not delete it)
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

	entry := readOpenCodeModel(t, configPath, "llama3.2")
	limit, ok := entry["limit"].(map[string]any)
	if !ok {
		t.Fatal("user-configured limit was removed")
	}
	if limit["context"] != float64(8192) {
		t.Errorf("context limit changed: got %v, want 8192", limit["context"])
	}
	if limit["output"] != float64(4096) {
		t.Errorf("output limit changed: got %v, want 4096", limit["output"])
	}
}

func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
	// Verify that when a cloud model entry has limits set (as Edit would do),
	// the structure matches what opencode expects and re-edit preserves them.
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")

	expected := cloudModelLimits["glm-4.7"]

	// Simulate a cloud model that already has the limit set by a previous Edit
	os.MkdirAll(configDir, 0o755)
	os.WriteFile(configPath, []byte(fmt.Sprintf(`{
		"provider": {
			"ollama": {
				"models": {
					"glm-4.7:cloud": {
						"name": "glm-4.7:cloud",
						"_launch": true,
						"limit": {"context": %d, "output": %d}
					}
				}
			}
		}
	}`, expected.Context, expected.Output)), 0o644)

	// Re-edit should preserve the cloud model limit
	if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
		t.Fatal(err)
	}

	entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
	limit, ok := entry["limit"].(map[string]any)
	if !ok {
		t.Fatal("cloud model limit was removed on re-edit")
	}
	if limit["context"] != float64(expected.Context) {
		t.Errorf("context = %v, want %d", limit["context"], expected.Context)
	}
	if limit["output"] != float64(expected.Output) {
		t.Errorf("output = %v, want %d", limit["output"], expected.Output)
	}
}

func TestLookupCloudModelLimit(t *testing.T) {
	tests := []struct {
		name        string
		wantOK      bool
		wantContext int
		wantOutput  int
	}{
		{"glm-4.7", true, 202_752, 131_072},
		{"glm-4.7:cloud", true, 202_752, 131_072},
		{"kimi-k2.5", true, 262_144, 262_144},
		{"kimi-k2.5:cloud", true, 262_144, 262_144},
		{"deepseek-v3.2", true, 163_840, 65_536},
		{"deepseek-v3.2:cloud", true, 163_840, 65_536},
		{"qwen3-coder:480b", true, 262_144, 65_536},
		{"llama3.2", false, 0, 0},
		{"unknown-model:cloud", false, 0, 0},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			l, ok := lookupCloudModelLimit(tt.name)
			if ok != tt.wantOK {
				t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
			}
			if ok {
				if l.Context != tt.wantContext {
					t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
				}
				if l.Output != tt.wantOutput {
					t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
				}
			}
		})
	}
}

func TestOpenCodeModels_NoConfig(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
@@ -17,7 +17,6 @@ const (
	ansiBold      = "\033[1m"
	ansiReset     = "\033[0m"
	ansiGray      = "\033[37m"
	ansiGreen     = "\033[32m"
	ansiClearDown = "\033[J"
)

@@ -96,14 +96,6 @@ func TestSelectState(t *testing.T) {
		}
	})

	t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
		s := newSelectState([]selectItem{})
		done, result, err := s.handleInput(eventEnter, 0)
		if done || result != "" || err != nil {
			t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
		}
	})

	t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
		s := newSelectState(items)
		done, result, err := s.handleInput(eventEscape, 0)

@@ -582,19 +574,8 @@ func TestRenderSelect(t *testing.T) {
		var buf bytes.Buffer
		renderSelect(&buf, "Select:", s)

		output := buf.String()
		if !strings.Contains(output, "no matches") {
			t.Errorf("expected 'no matches' message, got: %s", output)
		}
	})

	t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
		s := newSelectState([]selectItem{})
		var buf bytes.Buffer
		renderSelect(&buf, "Select:", s)

		if !strings.Contains(buf.String(), "no matches") {
			t.Error("expected 'no matches' message for empty list with no filter")
			t.Error("expected 'no matches' message")
		}
	})
@@ -10,21 +10,19 @@ import (
	"github.com/ollama/ollama/api"
)

var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")

func startApp(ctx context.Context, client *api.Client) error {
	exe, err := os.Executable()
	if err != nil {
		return errNotRunning
		return err
	}
	link, err := os.Readlink(exe)
	if err != nil {
		return errNotRunning
		return err
	}
	r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
	m := r.FindStringSubmatch(link)
	if len(m) != 1 {
		return errNotRunning
		return errors.New("could not find ollama app")
	}
	if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
		return err
509
cmd/tui/selector.go
Normal file
@@ -0,0 +1,509 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
selectorTitleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("147"))
|
||||
|
||||
selectorItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4)
|
||||
|
||||
selectorSelectedItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Foreground(lipgloss.Color("147")).
|
||||
Bold(true)
|
||||
|
||||
selectorDescStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
selectorFilterStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
|
||||
selectorInputStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252"))
|
||||
|
||||
selectorCheckboxStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("147"))
|
||||
|
||||
selectorDefaultTagStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
|
||||
selectorHelpStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("241"))
|
||||
|
||||
selectorMoreStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4).
|
||||
Foreground(lipgloss.Color("241")).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
const maxSelectorItems = 10
|
||||
|
||||
// ErrCancelled is returned when the user cancels the selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// SelectItem represents an item that can be selected.
|
||||
type SelectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// selectorModel is the bubbletea model for single selection.
|
||||
type selectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
selected string
|
||||
cancelled bool
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m selectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.selected = filtered[m.cursor].Name
|
||||
}
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.scrollOffset -= maxSelectorItems
|
||||
if m.scrollOffset < 0 {
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m selectorModel) View() string {
|
||||
// Clear screen when exiting
|
||||
if m.cancelled || m.selected != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
// Title with filter
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
|
||||
if item.Description != "" {
|
||||
s.WriteString(" ")
|
||||
s.WriteString(selectorDescStyle.Render("- " + item.Description))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
// SelectSingle prompts the user to select a single item from a list.
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(selectorModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.selected, nil
|
||||
}
|
||||
|
||||
// multiSelectorModel is the bubbletea model for multi selection.
type multiSelectorModel struct {
	title        string
	items        []SelectItem
	itemIndex    map[string]int
	filter       string
	cursor       int
	scrollOffset int
	checked      map[int]bool
	checkOrder   []int
	cancelled    bool
	confirmed    bool
}

func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
	m := multiSelectorModel{
		title:     title,
		items:     items,
		itemIndex: make(map[string]int, len(items)),
		checked:   make(map[int]bool),
	}

	for i, item := range items {
		m.itemIndex[item.Name] = i
	}

	for _, name := range preChecked {
		if idx, ok := m.itemIndex[name]; ok {
			m.checked[idx] = true
			m.checkOrder = append(m.checkOrder, idx)
		}
	}

	return m
}

func (m multiSelectorModel) filteredItems() []SelectItem {
	if m.filter == "" {
		return m.items
	}
	filterLower := strings.ToLower(m.filter)
	var result []SelectItem
	for _, item := range m.items {
		if strings.Contains(strings.ToLower(item.Name), filterLower) {
			result = append(result, item)
		}
	}
	return result
}

func (m *multiSelectorModel) toggleItem() {
	filtered := m.filteredItems()
	if len(filtered) == 0 || m.cursor >= len(filtered) {
		return
	}

	item := filtered[m.cursor]
	origIdx := m.itemIndex[item.Name]

	if m.checked[origIdx] {
		delete(m.checked, origIdx)
		for i, idx := range m.checkOrder {
			if idx == origIdx {
				m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
				break
			}
		}
	} else {
		m.checked[origIdx] = true
		m.checkOrder = append(m.checkOrder, origIdx)
	}
}

func (m multiSelectorModel) selectedCount() int {
	return len(m.checkOrder)
}

func (m multiSelectorModel) Init() tea.Cmd {
	return nil
}

func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		filtered := m.filteredItems()

		switch msg.Type {
		case tea.KeyCtrlC, tea.KeyEsc:
			m.cancelled = true
			return m, tea.Quit

		case tea.KeyEnter:
			// Enter confirms if at least one item is selected
			if len(m.checkOrder) > 0 {
				m.confirmed = true
				return m, tea.Quit
			}

		case tea.KeySpace:
			// Space always toggles selection
			m.toggleItem()

		case tea.KeyUp:
			if m.cursor > 0 {
				m.cursor--
				if m.cursor < m.scrollOffset {
					m.scrollOffset = m.cursor
				}
			}

		case tea.KeyDown:
			if m.cursor < len(filtered)-1 {
				m.cursor++
				if m.cursor >= m.scrollOffset+maxSelectorItems {
					m.scrollOffset = m.cursor - maxSelectorItems + 1
				}
			}

		case tea.KeyPgUp:
			m.cursor -= maxSelectorItems
			if m.cursor < 0 {
				m.cursor = 0
			}
			m.scrollOffset -= maxSelectorItems
			if m.scrollOffset < 0 {
				m.scrollOffset = 0
			}

		case tea.KeyPgDown:
			m.cursor += maxSelectorItems
			if m.cursor >= len(filtered) {
				m.cursor = len(filtered) - 1
			}
			if m.cursor >= m.scrollOffset+maxSelectorItems {
				m.scrollOffset = m.cursor - maxSelectorItems + 1
			}

		case tea.KeyBackspace:
			if len(m.filter) > 0 {
				m.filter = m.filter[:len(m.filter)-1]
				m.cursor = 0
				m.scrollOffset = 0
			}

		case tea.KeyRunes:
			m.filter += string(msg.Runes)
			m.cursor = 0
			m.scrollOffset = 0
		}
	}

	return m, nil
}

func (m multiSelectorModel) View() string {
	// Clear screen when exiting
	if m.cancelled || m.confirmed {
		return ""
	}

	var s strings.Builder

	// Title with filter
	s.WriteString(selectorTitleStyle.Render(m.title))
	s.WriteString(" ")
	if m.filter == "" {
		s.WriteString(selectorFilterStyle.Render("Type to filter..."))
	} else {
		s.WriteString(selectorInputStyle.Render(m.filter))
	}
	s.WriteString("\n\n")

	filtered := m.filteredItems()

	if len(filtered) == 0 {
		s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
		s.WriteString("\n")
	} else {
		displayCount := min(len(filtered), maxSelectorItems)

		for i := range displayCount {
			idx := m.scrollOffset + i
			if idx >= len(filtered) {
				break
			}
			item := filtered[idx]
			origIdx := m.itemIndex[item.Name]

			// Checkbox
			var checkbox string
			if m.checked[origIdx] {
				checkbox = selectorCheckboxCheckedStyle.Render("[x]")
			} else {
				checkbox = selectorCheckboxStyle.Render("[ ]")
			}

			// Cursor and name
			var line string
			if idx == m.cursor {
				line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
			} else {
				line = " " + checkbox + " " + item.Name
			}

			// Default tag
			if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
				line += " " + selectorDefaultTagStyle.Render("(default)")
			}

			s.WriteString(line)
			s.WriteString("\n")
		}

		if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
			s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
			s.WriteString("\n")
		}
	}

	s.WriteString("\n")

	// Status line
	count := m.selectedCount()
	if count == 0 {
		s.WriteString(selectorDescStyle.Render(" Select at least one model."))
	} else {
		s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
	}
	s.WriteString("\n\n")

	s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))

	return s.String()
}

// SelectMultiple prompts the user to select multiple items from a list.
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
	if len(items) == 0 {
		return nil, fmt.Errorf("no items to select from")
	}

	m := newMultiSelectorModel(title, items, preChecked)

	p := tea.NewProgram(m)
	finalModel, err := p.Run()
	if err != nil {
		return nil, fmt.Errorf("error running selector: %w", err)
	}

	fm := finalModel.(multiSelectorModel)
	if fm.cancelled {
		return nil, ErrCancelled
	}

	if !fm.confirmed {
		return nil, ErrCancelled
	}

	var result []string
	for _, idx := range fm.checkOrder {
		result = append(result, fm.items[idx].Name)
	}

	return result, nil
}
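For orientation, a hedged sketch of driving SelectMultiple: because checkOrder preserves the order items were toggled, the first name in the returned slice is the one View tags as "(default)". The title and names below are placeholders.

	picked, err := SelectMultiple("Enable models:", items, []string{"llama3.2"})
	if err != nil {
		// ErrCancelled covers both esc and quitting without confirming
		return err
	}
	defaultModel := picked[0] // first-checked entry doubles as the default
	fmt.Println("default:", defaultModel, "all:", picked)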
808
cmd/tui/tui.go
Normal file
@@ -0,0 +1,808 @@
package tui

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"runtime"
	"strings"
	"time"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/lipgloss"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/cmd/config"
	"github.com/ollama/ollama/version"
)

const (
	logoNormal = ` ▆▁▂▃▂▁▆
▟███████▙
█▙▛▅ ▅▜▟█
▟█▙▀▀▀▟█▙
█████████
▟███████▙
▀▀▀▀▀▀▀▀▀`

	logoBlink = ` ▆▁▂▃▂▁▆
▟███████▙
██▛▅ ▅▜██
▟█▙▀▀▀▟█▙
█████████
▟███████▙
▀▀▀▀▀▀▀▀▀`

	// logoBlank is used for terminals that don't render the logo well
	logoBlank = `





`

	blinkInterval = 15 * time.Second
	blinkDuration = 250 * time.Millisecond
)

type (
	blinkMsg   struct{}
	unblinkMsg struct{}
)

var (
	logoStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("255")).
			Background(lipgloss.Color("0"))

	titleStyle = lipgloss.NewStyle().
			Bold(true).
			MarginBottom(1)

	versionStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("245"))

	itemStyle = lipgloss.NewStyle().
			PaddingLeft(2)

	selectedStyle = lipgloss.NewStyle().
			PaddingLeft(2).
			Foreground(lipgloss.Color("147")).
			Bold(true)

	greyedStyle = lipgloss.NewStyle().
			PaddingLeft(2).
			Foreground(lipgloss.Color("241"))

	greyedSelectedStyle = lipgloss.NewStyle().
			PaddingLeft(2).
			Foreground(lipgloss.Color("243"))

	descStyle = lipgloss.NewStyle().
			PaddingLeft(4).
			Foreground(lipgloss.Color("241"))

	modelStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("245"))

	notInstalledStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("241")).
			Italic(true)
)
type menuItem struct {
	title       string
	description string
	integration string // integration name for loading model config, empty if not an integration
	isRunModel  bool   // true for the "Run a model" option
	isOthers    bool   // true for the "Others..." toggle item
}

var mainMenuItems = []menuItem{
	{
		title:       "Run a model",
		description: "Start an interactive chat with a local model",
		isRunModel:  true,
	},
	{
		title:       "Launch Claude Code",
		description: "Open Claude Code AI assistant",
		integration: "claude",
	},
	{
		title:       "Launch Open Claw",
		description: "Open the Open Claw integration",
		integration: "openclaw",
	},
}

var othersMenuItem = menuItem{
	title:       "Others...",
	description: "Show additional integrations",
	isOthers:    true,
}

// getOtherIntegrations returns the list of other integrations, filtering out
// Codex if it's not installed (since it requires npm install).
func getOtherIntegrations() []menuItem {
	items := []menuItem{
		{
			title:       "Launch Droid",
			description: "Open Droid integration",
			integration: "droid",
		},
		{
			title:       "Launch Open Code",
			description: "Open Open Code integration",
			integration: "opencode",
		},
	}

	// Only show Codex if it's already installed
	if config.IsIntegrationInstalled("codex") {
		items = append([]menuItem{{
			title:       "Launch Codex",
			description: "Open Codex CLI",
			integration: "codex",
		}}, items...)
	}

	return items
}
type model struct {
	items           []menuItem
	cursor          int
	quitting        bool
	selected        bool            // true if user made a selection (enter/space)
	changeModel     bool            // true if user pressed right arrow to change model
	showOthers      bool            // true if "Others..." is expanded
	availableModels map[string]bool // cache of available model names
	blinking        bool            // true when showing blink logo
	err             error

	// Modal state
	showingModal  bool          // true when model picker modal is visible
	modalSelector selectorModel // the selector model for the modal
	modalItems    []SelectItem  // cached items for the modal

	// Sign-in dialog state
	showingSignIn   bool   // true when sign-in dialog is visible
	signInURL       string // URL for sign-in
	signInModel     string // model that requires sign-in
	signInSpinner   int    // spinner frame index
	signInFromModal bool   // true if sign-in was triggered from modal (not main menu)
}

// signInTickMsg is sent to animate the sign-in spinner
type signInTickMsg struct{}

// signInCheckMsg is sent to check if sign-in is complete
type signInCheckMsg struct {
	signedIn bool
	userName string
}

// modelExists checks if a model exists in the cached available models.
func (m *model) modelExists(name string) bool {
	if m.availableModels == nil || name == "" {
		return false
	}
	if m.availableModels[name] {
		return true
	}
	// Check for prefix match (e.g., "llama2" matches "llama2:latest")
	for modelName := range m.availableModels {
		if strings.HasPrefix(modelName, name+":") {
			return true
		}
	}
	return false
}

// buildModalItems creates the list of models for the modal selector.
func (m *model) buildModalItems() []SelectItem {
	modelItems, _ := config.GetModelItems(context.Background())
	var items []SelectItem
	for _, item := range modelItems {
		items = append(items, SelectItem{Name: item.Name, Description: item.Description})
	}
	return items
}

// openModelModal opens the model picker modal.
func (m *model) openModelModal() {
	m.modalItems = m.buildModalItems()
	m.modalSelector = selectorModel{
		title: "Select model:",
		items: m.modalItems,
	}
	m.showingModal = true
}

// isCloudModel returns true if the model name indicates a cloud model.
func isCloudModel(name string) bool {
	return strings.HasSuffix(name, ":cloud")
}

// checkCloudSignIn checks if a cloud model needs sign-in.
// Returns a command to start sign-in if needed, or nil if already signed in.
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
	if modelName == "" || !isCloudModel(modelName) {
		return nil
	}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil
	}
	user, err := client.Whoami(context.Background())
	if err == nil && user != nil && user.Name != "" {
		return nil // Already signed in
	}
	var aErr api.AuthorizationError
	if errors.As(err, &aErr) && aErr.SigninURL != "" {
		return m.startSignIn(modelName, aErr.SigninURL, fromModal)
	}
	return nil
}
// startSignIn initiates the sign-in flow for a cloud model.
// fromModal indicates if this was triggered from the model picker modal.
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
	m.showingModal = false
	m.showingSignIn = true
	m.signInURL = signInURL
	m.signInModel = modelName
	m.signInSpinner = 0
	m.signInFromModal = fromModal

	// Open browser (best effort)
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", signInURL).Start()
	case "linux":
		_ = exec.Command("xdg-open", signInURL).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", signInURL).Start()
	}

	// Start the spinner tick
	return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
		return signInTickMsg{}
	})
}

// checkSignIn checks if the user has completed sign-in.
func checkSignIn() tea.Msg {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return signInCheckMsg{signedIn: false}
	}
	user, err := client.Whoami(context.Background())
	if err == nil && user != nil && user.Name != "" {
		return signInCheckMsg{signedIn: true, userName: user.Name}
	}
	return signInCheckMsg{signedIn: false}
}

// loadAvailableModels fetches and caches the list of available models.
func (m *model) loadAvailableModels() {
	m.availableModels = make(map[string]bool)
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return
	}
	models, err := client.List(context.Background())
	if err != nil {
		return
	}
	for _, mdl := range models.Models {
		m.availableModels[mdl.Name] = true
	}
}

func (m *model) buildItems() {
	others := getOtherIntegrations()
	m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
	m.items = append(m.items, mainMenuItems...)

	if m.showOthers {
		// Change "Others..." to "Hide others..."
		hideItem := menuItem{
			title:       "Hide others...",
			description: "Hide additional integrations",
			isOthers:    true,
		}
		m.items = append(m.items, hideItem)
		m.items = append(m.items, others...)
	} else {
		m.items = append(m.items, othersMenuItem)
	}
}

// isOthersIntegration returns true if the integration is in the "Others" menu
func isOthersIntegration(name string) bool {
	switch name {
	case "codex", "droid", "opencode":
		return true
	}
	return false
}

func initialModel() model {
	m := model{
		cursor: 0,
	}
	m.loadAvailableModels()

	// Check last selection to determine if we need to expand "Others"
	lastSelection := config.LastSelection()
	if isOthersIntegration(lastSelection) {
		m.showOthers = true
	}

	m.buildItems()

	// Position cursor on last selection
	if lastSelection != "" {
		for i, item := range m.items {
			if lastSelection == "run" && item.isRunModel {
				m.cursor = i
				break
			} else if item.integration == lastSelection {
				m.cursor = i
				break
			}
		}
	}

	return m
}

func (m model) Init() tea.Cmd {
	return tea.Tick(blinkInterval, func(t time.Time) tea.Msg {
		return blinkMsg{}
	})
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	// Handle sign-in dialog
	if m.showingSignIn {
		switch msg := msg.(type) {
		case tea.KeyMsg:
			switch msg.Type {
			case tea.KeyCtrlC, tea.KeyEsc:
				// Cancel sign-in and go back
				m.showingSignIn = false
				if m.signInFromModal {
					m.showingModal = true
				}
				// If from main menu, just return to main menu (default state)
				return m, nil
			}

		case signInTickMsg:
			m.signInSpinner++
			// Check sign-in status every 5th tick (~1 second)
			if m.signInSpinner%5 == 0 {
				return m, tea.Batch(
					tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
						return signInTickMsg{}
					}),
					checkSignIn,
				)
			}
			return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
				return signInTickMsg{}
			})

		case signInCheckMsg:
			if msg.signedIn {
				// Sign-in complete - proceed with selection
				if m.signInFromModal {
					// Came from modal - set changeModel
					m.modalSelector.selected = m.signInModel
					m.changeModel = true
				} else {
					// Came from main menu - just select
					m.selected = true
				}
				m.quitting = true
				return m, tea.Quit
			}
		}
		return m, nil
	}

	// Handle modal input if modal is showing
	if m.showingModal {
		switch msg := msg.(type) {
		case tea.KeyMsg:
			switch msg.Type {
			case tea.KeyCtrlC, tea.KeyEsc:
				// Close modal without selection
				m.showingModal = false
				return m, nil

			case tea.KeyEnter:
				filtered := m.modalSelector.filteredItems()
				if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
					m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
				}
				if m.modalSelector.selected != "" {
					if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
						return m, cmd
					}
					// Selection made - exit with changeModel
					m.changeModel = true
					m.quitting = true
					return m, tea.Quit
				}
				return m, nil

			case tea.KeyUp:
				if m.modalSelector.cursor > 0 {
					m.modalSelector.cursor--
					if m.modalSelector.cursor < m.modalSelector.scrollOffset {
						m.modalSelector.scrollOffset = m.modalSelector.cursor
					}
				}

			case tea.KeyDown:
				filtered := m.modalSelector.filteredItems()
				if m.modalSelector.cursor < len(filtered)-1 {
					m.modalSelector.cursor++
					if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
						m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
					}
				}

			case tea.KeyPgUp:
				filtered := m.modalSelector.filteredItems()
				m.modalSelector.cursor -= maxSelectorItems
				if m.modalSelector.cursor < 0 {
					m.modalSelector.cursor = 0
				}
				m.modalSelector.scrollOffset -= maxSelectorItems
				if m.modalSelector.scrollOffset < 0 {
					m.modalSelector.scrollOffset = 0
				}
				_ = filtered // suppress unused warning

			case tea.KeyPgDown:
				filtered := m.modalSelector.filteredItems()
				m.modalSelector.cursor += maxSelectorItems
				if m.modalSelector.cursor >= len(filtered) {
					m.modalSelector.cursor = len(filtered) - 1
				}
				if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
					m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
				}

			case tea.KeyBackspace:
				if len(m.modalSelector.filter) > 0 {
					m.modalSelector.filter = m.modalSelector.filter[:len(m.modalSelector.filter)-1]
					m.modalSelector.cursor = 0
					m.modalSelector.scrollOffset = 0
				}

			case tea.KeyRunes:
				m.modalSelector.filter += string(msg.Runes)
				m.modalSelector.cursor = 0
				m.modalSelector.scrollOffset = 0
			}
		}
		return m, nil
	}

	switch msg := msg.(type) {
	case blinkMsg:
		m.blinking = true
		return m, tea.Tick(blinkDuration, func(t time.Time) tea.Msg {
			return unblinkMsg{}
		})

	case unblinkMsg:
		m.blinking = false
		return m, tea.Tick(blinkInterval, func(t time.Time) tea.Msg {
			return blinkMsg{}
		})

	case tea.KeyMsg:
		switch msg.String() {
		case "ctrl+c", "q", "esc":
			m.quitting = true
			return m, tea.Quit

		case "up", "k":
			if m.cursor > 0 {
				m.cursor--
			}

		case "down", "j":
			if m.cursor < len(m.items)-1 {
				m.cursor++
			}

		case "enter", " ":
			item := m.items[m.cursor]

			// Handle "Others..." toggle
			if item.isOthers {
				m.showOthers = !m.showOthers
				m.buildItems()
				// Keep cursor on the Others/Hide item
				if m.cursor >= len(m.items) {
					m.cursor = len(m.items) - 1
				}
				return m, nil
			}

			// Don't allow selecting uninstalled integrations
			if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
				return m, nil
			}

			// Check if a cloud model is configured and needs sign-in
			var configuredModel string
			if item.isRunModel {
				configuredModel = config.LastModel()
			} else if item.integration != "" {
				configuredModel = config.IntegrationModel(item.integration)
			}
			if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
				return m, cmd
			}

			m.selected = true
			m.quitting = true
			return m, tea.Quit

		case "right", "l":
			// Allow model change for integrations and run model
			item := m.items[m.cursor]
			if item.integration != "" || item.isRunModel {
				// Don't allow for uninstalled integrations
				if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
					return m, nil
				}
				m.openModelModal()
			}
		}
	}

	return m, nil
}
func (m model) View() string {
	if m.quitting {
		return ""
	}

	// Render sign-in dialog if showing
	if m.showingSignIn {
		return m.renderSignInDialog()
	}

	// Render modal overlay if showing - replaces main view
	if m.showingModal {
		return m.renderModal()
	}

	logo := logoNormal
	if m.blinking {
		logo = logoBlink
	}
	if os.Getenv("TERM_PROGRAM") == "Apple_Terminal" {
		logo = logoBlank
	}

	versionText := "\n\n Ollama " + versionStyle.Render("v"+version.Version)

	logoRendered := logoStyle.Render(logo)
	logoBlock := lipgloss.NewStyle().Padding(0, 1).MarginLeft(2).Background(lipgloss.Color("0")).Render(logoRendered)
	versionBlock := titleStyle.Render(versionText)
	header := lipgloss.JoinHorizontal(lipgloss.Top, logoBlock, versionBlock)

	s := header + "\n\n"

	for i, item := range m.items {
		cursor := " "
		style := itemStyle
		isInstalled := true

		if item.integration != "" {
			isInstalled = config.IsIntegrationInstalled(item.integration)
		}

		if m.cursor == i {
			cursor = "▸ "
			if isInstalled {
				style = selectedStyle
			} else {
				style = greyedSelectedStyle
			}
		} else if !isInstalled && item.integration != "" {
			style = greyedStyle
		}

		title := item.title
		if item.integration != "" {
			if !isInstalled {
				title += " " + notInstalledStyle.Render("(not installed)")
			} else if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
				title += " " + modelStyle.Render("("+mdl+")")
			}
		} else if item.isRunModel {
			if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
				title += " " + modelStyle.Render("("+mdl+")")
			}
		}

		s += style.Render(cursor+title) + "\n"
		s += descStyle.Render(item.description) + "\n\n"
	}

	s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render("↑/↓ navigate • enter select • → change model • esc quit")

	return s
}

// renderModal renders the model picker modal.
func (m model) renderModal() string {
	modalStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("147")).
		Padding(1, 2).
		MarginLeft(2)

	var content strings.Builder

	// Title with filter
	content.WriteString(selectorTitleStyle.Render(m.modalSelector.title))
	content.WriteString(" ")
	if m.modalSelector.filter == "" {
		content.WriteString(selectorFilterStyle.Render("Type to filter..."))
	} else {
		content.WriteString(selectorInputStyle.Render(m.modalSelector.filter))
	}
	content.WriteString("\n\n")

	filtered := m.modalSelector.filteredItems()

	if len(filtered) == 0 {
		content.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
		content.WriteString("\n")
	} else {
		displayCount := min(len(filtered), maxSelectorItems)

		for i := range displayCount {
			idx := m.modalSelector.scrollOffset + i
			if idx >= len(filtered) {
				break
			}
			item := filtered[idx]

			if idx == m.modalSelector.cursor {
				content.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
			} else {
				content.WriteString(selectorItemStyle.Render(item.Name))
			}

			if item.Description != "" {
				content.WriteString(" ")
				content.WriteString(selectorDescStyle.Render("- " + item.Description))
			}
			content.WriteString("\n")
		}

		if remaining := len(filtered) - m.modalSelector.scrollOffset - displayCount; remaining > 0 {
			content.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
			content.WriteString("\n")
		}
	}

	content.WriteString("\n")
	content.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))

	return modalStyle.Render(content.String())
}

// renderSignInDialog renders the sign-in dialog.
func (m model) renderSignInDialog() string {
	dialogStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("147")).
		Padding(1, 2).
		MarginLeft(2)

	spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
	spinner := spinnerFrames[m.signInSpinner%len(spinnerFrames)]

	var content strings.Builder

	content.WriteString(selectorTitleStyle.Render("Sign in required"))
	content.WriteString("\n\n")

	content.WriteString(fmt.Sprintf("To use %s, please sign in.\n\n", selectedStyle.Render(m.signInModel)))

	content.WriteString("Navigate to:\n")
	content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Render(" " + m.signInURL))
	content.WriteString("\n\n")

	content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(
		fmt.Sprintf("%s Waiting for sign in to complete...", spinner)))
	content.WriteString("\n\n")

	content.WriteString(selectorHelpStyle.Render("esc cancel"))

	return dialogStyle.Render(content.String())
}
// Selection represents what the user selected
type Selection int

const (
	SelectionNone Selection = iota
	SelectionRunModel
	SelectionChangeRunModel
	SelectionIntegration       // Generic integration selection
	SelectionChangeIntegration // Generic change model for integration
)

// Result contains the selection and any associated data
type Result struct {
	Selection   Selection
	Integration string // integration name if applicable
	Model       string // model name if selected from modal
}

// Run starts the TUI and returns the user's selection
func Run() (Result, error) {
	m := initialModel()
	p := tea.NewProgram(m)

	finalModel, err := p.Run()
	if err != nil {
		return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
	}

	fm := finalModel.(model)
	if fm.err != nil {
		return Result{Selection: SelectionNone}, fm.err
	}

	// User quit without selecting
	if !fm.selected && !fm.changeModel {
		return Result{Selection: SelectionNone}, nil
	}

	item := fm.items[fm.cursor]

	// Handle model change request
	if fm.changeModel {
		if item.isRunModel {
			return Result{
				Selection: SelectionChangeRunModel,
				Model:     fm.modalSelector.selected,
			}, nil
		}
		return Result{
			Selection:   SelectionChangeIntegration,
			Integration: item.integration,
			Model:       fm.modalSelector.selected,
		}, nil
	}

	// Handle selection
	if item.isRunModel {
		return Result{Selection: SelectionRunModel}, nil
	}

	return Result{
		Selection:   SelectionIntegration,
		Integration: item.integration,
	}, nil
}
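A sketch of how a caller might dispatch on Run's Result; the switch arms mirror the Selection constants above, and the handler functions are placeholders, not part of this change.

	res, err := tui.Run()
	if err != nil {
		return err
	}
	switch res.Selection {
	case tui.SelectionRunModel:
		return startChat() // placeholder
	case tui.SelectionChangeRunModel:
		return startChatWith(res.Model) // placeholder
	case tui.SelectionIntegration:
		return launchIntegration(res.Integration) // placeholder
	case tui.SelectionChangeIntegration:
		return reconfigureIntegration(res.Integration, res.Model) // placeholder
	default:
		return nil // SelectionNone: user quit without choosing
	}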
23
go.mod
@@ -13,7 +13,7 @@ require (
	github.com/mattn/go-sqlite3 v1.14.24
	github.com/olekukonko/tablewriter v0.0.5
	github.com/spf13/cobra v1.7.0
	github.com/stretchr/testify v1.10.0
	github.com/stretchr/testify v1.9.0
	github.com/x448/float16 v0.8.4
	golang.org/x/sync v0.17.0
	golang.org/x/sys v0.37.0
@@ -21,16 +21,16 @@ require (

require (
	github.com/agnivade/levenshtein v1.1.1
	github.com/charmbracelet/bubbletea v1.3.10
	github.com/charmbracelet/lipgloss v1.1.0
	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
	github.com/dlclark/regexp2 v1.11.4
	github.com/emirpasic/gods/v2 v2.0.0-alpha
	github.com/mattn/go-runewidth v0.0.14
	github.com/mattn/go-runewidth v0.0.16
	github.com/nlpodyssey/gopickle v0.3.0
	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
	github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
	github.com/tkrajina/typescriptify-golang-structs v0.2.0
	github.com/tree-sitter/go-tree-sitter v0.25.0
	github.com/tree-sitter/tree-sitter-cpp v0.23.4
	github.com/wk8/go-ordered-map/v2 v2.1.8
	golang.org/x/image v0.22.0
	golang.org/x/mod v0.30.0
@@ -40,23 +40,34 @@ require (

require (
	github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
	github.com/bahlo/generic-list-go v0.2.0 // indirect
	github.com/buger/jsonparser v1.1.1 // indirect
	github.com/bytedance/sonic/loader v0.1.1 // indirect
	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
	github.com/charmbracelet/x/ansi v0.10.1 // indirect
	github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
	github.com/charmbracelet/x/term v0.2.1 // indirect
	github.com/chewxy/hm v1.0.0 // indirect
	github.com/chewxy/math32 v1.11.0 // indirect
	github.com/cloudwego/base64x v0.1.4 // indirect
	github.com/cloudwego/iasm v0.2.0 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
	github.com/gogo/protobuf v1.3.2 // indirect
	github.com/google/flatbuffers v24.3.25+incompatible // indirect
	github.com/kr/text v0.2.0 // indirect
	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
	github.com/mailru/easyjson v0.7.7 // indirect
	github.com/mattn/go-pointer v0.0.1 // indirect
	github.com/mattn/go-localereader v0.0.1 // indirect
	github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
	github.com/muesli/cancelreader v0.2.2 // indirect
	github.com/muesli/termenv v0.16.0 // indirect
	github.com/pkg/errors v0.9.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/rivo/uniseg v0.2.0 // indirect
	github.com/rivo/uniseg v0.4.7 // indirect
	github.com/tkrajina/go-reflector v0.5.5 // indirect
	github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
	github.com/xtgo/set v1.0.0 // indirect
	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
67
go.sum
@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -148,15 +164,17 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -164,6 +182,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -184,8 +208,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -208,39 +233,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -249,6 +247,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -335,6 +335,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -144,47 +144,3 @@ func TestUnicodeModelDir(t *testing.T) {
	}
	ChatTestHelper(ctx, t, req, blueSkyExpected)
}

// TestNumPredict verifies that when num_predict is set, the model generates
// exactly that many tokens. It uses logprobs to count the actual tokens output.
func TestNumPredict(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
		t.Fatalf("failed to pull model: %v", err)
	}

	req := api.GenerateRequest{
		Model:    "qwen3:0.6b",
		Prompt:   "Write a long story.",
		Stream:   &stream,
		Logprobs: true,
		Options: map[string]any{
			"num_predict": 10,
			"temperature": 0,
			"seed":        123,
		},
	}

	logprobCount := 0
	var finalResponse api.GenerateResponse
	err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
		logprobCount += len(resp.Logprobs)
		if resp.Done {
			finalResponse = resp
		}
		return nil
	})
	if err != nil {
		t.Fatalf("generate failed: %v", err)
	}

	if logprobCount != 10 {
		t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
			logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
	}
}
@@ -34,7 +34,6 @@ import (
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/tokenizer"
)

type filteredEnv []string
@@ -117,7 +116,7 @@ type llamaServer struct {
type ollamaServer struct {
	llmServer

	tokenizer     tokenizer.Tokenizer // tokenizer handles text encoding/decoding
	textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
}

// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -143,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
	var llamaModel *llama.Model
	var tok tokenizer.Tokenizer
	var textProcessor model.TextProcessor
	var err error
	if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
		if len(projectors) == 0 {
			tok, err = model.NewTextProcessor(modelPath)
			textProcessor, err = model.NewTextProcessor(modelPath)
		} else {
			err = errors.New("split vision models aren't supported")
		}
@@ -156,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
		}
	}
	if tok == nil {
	if textProcessor == nil {
		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
		if err != nil {
			return nil, err
@@ -212,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st

	kvct := strings.ToLower(envconfig.KvCacheType())

	if tok == nil {
	if textProcessor == nil {
		flashAttention := ml.FlashAttentionAuto
		if faUserSet {
			if fa {
@@ -262,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
	gpuLibs := ml.LibraryPaths(gpus)
	status := NewStatusWriter(os.Stderr)
	cmd, port, err := StartRunner(
		tok != nil,
		textProcessor != nil,
		modelPath,
		gpuLibs,
		status,
@@ -311,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
		}
	}()

	if tok != nil {
		return &ollamaServer{llmServer: s, tokenizer: tok}, nil
	if textProcessor != nil {
		return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
	} else {
		return &llamaServer{llmServer: s, ggml: f}, nil
	}
@@ -1775,7 +1774,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}

func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
	tokens, err := s.tokenizer.Encode(content, false)
	tokens, err := s.textProcessor.Encode(content, false)
	if err != nil {
		return nil, err
	}
@@ -1810,7 +1809,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
		toks[i] = int32(t)
	}

	content, err := s.tokenizer.Decode(toks)
	content, err := s.textProcessor.Decode(toks)
	if err != nil {
		return "", err
	}
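Illustrative only: with the rename above, a tokenize/detokenize round trip goes through the new textProcessor field. Assuming an *ollamaServer s and a context ctx, something like:

	ids, err := s.Tokenize(ctx, "hello world")
	if err != nil {
		return err
	}
	text, err := s.Detokenize(ctx, ids)
	if err != nil {
		return err
	}
	// text should round-trip to "hello world" for typical BPE vocabularies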
@@ -131,15 +131,12 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {

	messageID := anthropic.GenerateMessageID()

	// Estimate input tokens for streaming (actual count not available until generation completes)
	estimatedTokens := anthropic.EstimateInputTokens(req)

	w := &AnthropicWriter{
		BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		stream:     req.Stream,
		id:         messageID,
		model:      req.Model,
		converter:  anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
		converter:  anthropic.NewStreamConverter(messageID, req.Model),
	}

	if req.Stream {
272
model/bytepairencoding.go
Normal file
@@ -0,0 +1,272 @@
package model

import (
	"cmp"
	"iter"
	"slices"
	"strings"

	"github.com/dlclark/regexp2"
	heap "github.com/emirpasic/gods/v2/trees/binaryheap"
	"github.com/ollama/ollama/logutil"
)

type BytePairEncoding struct {
	vocab   *Vocabulary
	regexps []*regexp2.Regexp
}

var _ TextProcessor = (*BytePairEncoding)(nil)

func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
	if len(pretokenizers) == 0 {
		// set default byte-level pretokenizer if none provided, e.g.
		// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
		pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
	}

	return BytePairEncoding{
		vocab: vocab,
		regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
			for _, p := range pretokenizers {
				if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
					return
				}
			}
		}),
	}
}
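A hedged construction sketch: passing a pretokenizer explicitly overrides the byte-level default above. The vocab value is assumed to be a *Vocabulary loaded elsewhere (e.g. from a model file); the pattern here is the same GPT-2 style regex the default uses.

	// vocab is assumed to come from elsewhere in the package
	bpe := NewBytePairEncoding(
		vocab,
		`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
	)
	ids, err := bpe.Encode("Hello, world!", false)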
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
	return bpe.vocab
}

func (bpe BytePairEncoding) Is(id int32, special Special) bool {
	return bpe.vocab.Is(id, special)
}

func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
	parts := []string{s}
	for _, re := range bpe.regexps {
		parts = slices.Collect(func(yield func(string) bool) {
			for _, part := range parts {
				r := []rune(part)
				var offset int
				for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
					if offset-m.Index != 0 {
						if !yield(string(r[:m.Index])) {
							return
						}
					}

					if !yield(m.String()) {
						return
					}

					offset = m.Index + m.Length
				}

				if offset < len(r) {
					if !yield(string(r[offset:])) {
						return
					}
				}
			}
		})
	}

	return slices.Values(parts)
}
// fragment is a string fragment and its corresponding token IDs
type fragment struct {
value string
ids []int32
}

// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}

type merge struct {
p, n int
runes []rune
}

func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}

var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}

fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}

var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}

for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}

sb.WriteRune(r)
}

// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}

runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}

pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}

left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}

return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}

pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})

for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}

for !pairs.Empty() {
pair, _ := pairs.Pop()

left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}

if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}

merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil

merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}

if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}

if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}

for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}

if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}

logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}

func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}

// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}

logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}
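The switch statements in Encode and Decode above implement the GPT-2 style byte-to-printable-rune remapping: control bytes, the DEL..NBSP range, and the soft hyphen are shifted into otherwise-unused rune ranges so every byte has a visible vocabulary form, and Decode reverses the shift. A self-contained sketch of just that mapping (not part of the diff):

package main

import "fmt"

// encodeByte mirrors the byte-to-rune remapping in Encode above.
func encodeByte(b byte) rune {
	r := rune(b)
	switch {
	case r == 0x00ad: // soft hyphen
		r = 0x0143
	case r <= 0x0020: // control chars and space
		r = r + 0x0100
	case r >= 0x007f && r <= 0x00a0: // DEL through NBSP
		r = r + 0x00a2
	}
	return r
}

// decodeRune mirrors the inverse mapping in Decode above.
func decodeRune(r rune) byte {
	switch {
	case r == 0x0143:
		r = 0x00ad
	case r > 0x0100 && r <= 0x0120:
		r = r - 0x0100
	case r > 0x0120 && r <= 0x0142:
		r = r - 0x00a2
	}
	return byte(r)
}

func main() {
	// Each byte maps to a printable rune and back unchanged.
	for _, b := range []byte{' ', '\n', 'a', 0x7f, 0xad} {
		r := encodeByte(b)
		fmt.Printf("0x%02x -> U+%04X -> 0x%02x\n", b, r, decodeRune(r))
	}
}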
@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()

f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}

f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}
@@ -23,7 +23,6 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

var (
@@ -134,7 +133,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}

func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -151,7 +150,7 @@ func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
return nil, err
}

tp, ok := m.(tokenizer.Tokenizer)
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -130,7 +129,7 @@ func (o Options) headDim() int {
}

func New(c fs.Config) (model.Model, error) {
vocab := &tokenizer.Vocabulary{
vocab := &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -154,17 +153,17 @@ func New(c fs.Config) (model.Model, error) {
},
}

var t tokenizer.Tokenizer
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
t = tokenizer.NewWordPiece(vocab, true)
processor = model.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}

return &Model{
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Options struct {
@@ -223,7 +222,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -278,8 +277,8 @@ func New(c fs.Config) (model.Model, error) {
}

m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -135,8 +134,8 @@ func init() {
}

m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Options struct {
@@ -28,7 +27,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions

type Model struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,8 +43,8 @@ const (

func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

@@ -7,12 +7,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type embedModel struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece

*TextModel
poolingType pooling.Type
@@ -32,8 +31,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro

func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

@@ -12,12 +12,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +54,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}

func New(c fs.Config) (model.Model, error) {
vocabulary := tokenizer.Vocabulary{
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -71,19 +70,19 @@ func New(c fs.Config) (model.Model, error) {
),
}

var t tokenizer.Tokenizer
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
t = tokenizer.NewBytePairEncoding(&vocabulary)
processor = model.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
t = tokenizer.NewSentencePiece(&vocabulary)
processor = model.NewSentencePiece(&vocabulary)
}

m := Model{
Tokenizer: t,
TextProcessor: processor,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

@@ -6,12 +6,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece

*TextModel
}
@@ -24,8 +23,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -199,7 +198,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -237,8 +236,8 @@ func New(c fs.Config) (model.Model, error) {
}

m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

*TextModel
*VisionModel `gguf:"v"`
@@ -38,8 +37,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)

m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -12,12 +12,11 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Transformer struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -197,8 +196,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Options struct {
@@ -60,7 +59,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -79,7 +78,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}

vocabulary := tokenizer.Vocabulary{
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -105,8 +104,8 @@ func New(c fs.Config) (model.Model, error) {
}

m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Options struct {
@@ -26,7 +25,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -42,8 +41,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}

var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
var processor model.TextProcessor
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -81,16 +80,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = tokenizer.NewSentencePiece(&vocabulary)
processor = model.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}

m := Model{
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
ImageProcessor

*VisionModel `gguf:"v"`
@@ -34,8 +33,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {

func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

*TextModel
*VisionModel `gguf:"v"`
@@ -29,12 +28,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)

// Implement TextProcessor interface
var _ tokenizer.Tokenizer = (*Model)(nil)
var _ model.TextProcessor = (*Model)(nil)

func New(c fs.Config) (model.Model, error) {
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

*VisionModel `gguf:"v"`
*TextModel
@@ -33,8 +32,8 @@ const (

func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -179,6 +178,29 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads

processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)

blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -197,29 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}

return &Model{
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
TextProcessor: processor,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

const (
@@ -34,7 +33,7 @@ type Options struct {

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -45,24 +44,28 @@ type Model struct {
}

func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
Layers: make([]Layer, c.Uint("block_count")),
}

processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)

m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Options struct {
@@ -93,7 +92,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -140,8 +139,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

*TextModel
*VisionModel `gguf:"v"`
@@ -28,8 +27,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)

func New(c fs.Config) (model.Model, error) {
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -7,12 +7,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type embedModel struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

*Model
poolingType pooling.Type
@@ -35,8 +34,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Options struct {
@@ -160,7 +159,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens

type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -219,8 +218,8 @@ func New(c fs.Config) (model.Model, error) {
}

m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

// Options contains model configuration
@@ -208,7 +207,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding

TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -354,8 +353,8 @@ func New(c fs.Config) (model.Model, error) {
}

m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)

type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor

*TextModel
*VisionModel `gguf:"v"`
@@ -173,8 +172,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}

var _ Tokenizer = (*SentencePiece)(nil)
var _ TextProcessor = (*SentencePiece)(nil)

func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")

// For tokenizer that use byte tokens like "<0xEA>"
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()

bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
17
model/textprocessor.go
Normal file
@@ -0,0 +1,17 @@
package model

const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)

type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
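Since BytePairEncoding, SentencePiece, and WordPiece all satisfy this interface, callers can stay tokenizer-agnostic. A small sketch of such a caller (countTokens is hypothetical, not part of the diff):

func countTokens(tp model.TextProcessor, prompt string) (int, error) {
	ids, err := tp.Encode(prompt, true) // true: include BOS/EOS specials
	if err != nil {
		return 0, err
	}
	return len(ids), nil
}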
@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"log/slog"
@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"testing"
@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)

// Decode implements Tokenizer.
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}

// Encode implements Tokenizer.
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32

@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}

// Is implements Tokenizer.
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}

// Vocabulary implements Tokenizer.
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}

var _ Tokenizer = (*WordPiece)(nil)
var _ TextProcessor = (*WordPiece)(nil)

func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{
@@ -1,4 +1,4 @@
package tokenizer
package model

import (
"slices"
@@ -37,7 +37,6 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"

_ "github.com/ollama/ollama/model/models"
)
@@ -211,9 +210,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}

// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := tok.Decode([]int32{int32(tokenID)})
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -243,7 +242,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,

for i, part := range parts {
// text - tokenize
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -515,6 +514,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}

// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}

if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -703,6 +709,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}

seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -738,9 +745,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}

seq.lastUpdatedAt = t
seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -765,7 +770,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token

// if it's an end of sequence token, break
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -774,25 +779,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}

piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}

// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}

seq.pendingResponses = append(seq.pendingResponses, piece)

// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, llm.DoneReasonLength)
continue
}

sequence := strings.Join(seq.pendingResponses, "")

if ok, stop := common.FindStop(sequence, seq.stop); ok {
@@ -879,7 +877,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

@@ -3,7 +3,6 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)

@@ -12,15 +11,22 @@ func Execute(args []string) error {
args = args[1:]
}

if len(args) > 0 {
switch args[0] {
case "--ollama-engine":
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
var newRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
mlxRunner = true
}

if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
}
return llamarunner.Execute(args)
}

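The rewritten dispatch keys on the first argument alone and forwards everything after the engine flag untouched. Illustrative calls (not from the diff, shown only to make the dispatch concrete):

err := runner.Execute([]string{"--ollama-engine"})  // new Go engine
err = runner.Execute([]string{"--imagegen-engine"}) // image generation engine
err = runner.Execute([]string{"--mlx-engine"})      // MLX engine
err = runner.Execute(nil)                           // no engine flag: defaults to llamarunner
_ = err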
@@ -7,7 +7,7 @@ import (
"slices"

"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/model"
)

// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}

func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}

grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

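The signature change means any model.TextProcessor can seed the grammar sampler. A minimal sketch of the call, assuming tp satisfies model.TextProcessor; the GBNF grammar string is illustrative:

grammarStr := `root ::= "yes" | "no"` // illustrative GBNF grammar
sampler, err := sample.NewGrammarSampler(tp, grammarStr)
if err != nil {
	return err
}
_ = sampler // applied to logits during token selection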
@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"

"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/model"
)

func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}

func modelHelper(t testing.TB) tokenizer.Tokenizer {
func modelHelper(t testing.TB) model.BytePairEncoding {
t.Helper()

f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) tokenizer.Tokenizer {

merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
return model.NewBytePairEncoding(
&model.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

@@ -1,5 +1,5 @@
#!/bin/sh
# This script installs Ollama on Linux and macOS.
# This script installs Ollama on Linux.
# It detects the current operating system architecture and installs the appropriate version of Ollama.

set -eu
@@ -27,7 +27,8 @@ require() {
echo $MISSING
}

OS="$(uname -s)"
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'

ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
@@ -35,65 +36,6 @@
*) error "Unsupported architecture: $ARCH" ;;
esac

###########################################
# macOS
###########################################

if [ "$OS" = "Darwin" ]; then
NEEDS=$(require curl unzip)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi

if [ -n "${OLLAMA_VERSION:-}" ]; then
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
else
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
fi

if pgrep -x Ollama >/dev/null 2>&1; then
status "Stopping running Ollama instance..."
pkill -x Ollama 2>/dev/null || true
sleep 2
fi

if [ -d "/Applications/Ollama.app" ]; then
status "Removing existing Ollama installation..."
rm -rf "/Applications/Ollama.app"
fi

status "Downloading Ollama for macOS..."
curl --fail --show-error --location --progress-bar \
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"

status "Installing Ollama to /Applications..."
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
mv "$TEMP_DIR/Ollama.app" "/Applications/"

status "Adding 'ollama' command to PATH (may require password)..."
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"

if [ -z "${OLLAMA_NO_START:-}" ]; then
status "Starting Ollama..."
open -a Ollama --args hidden
fi

status "Install complete. You can now run 'ollama'."
exit 0
fi

###########################################
# Linux
###########################################

[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'

IS_WSL2=false

KERN=$(uname -r)

@@ -1,422 +0,0 @@
package server

import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"sync"

"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)

const (
serverConfigFilename = "server.json"
serverConfigVersion = 1
)

var errAliasCycle = errors.New("alias cycle detected")

type aliasEntry struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}

type serverConfig struct {
Version int `json:"version"`
Aliases []aliasEntry `json:"aliases"`
}

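For reference, the on-disk server.json that this now-removed store read and wrote follows directly from the struct tags above. A hypothetical config, with alias and target names that are illustrative only:

// Sketch of the persisted shape implied by serverConfig/aliasEntry above.
cfg := serverConfig{
	Version: serverConfigVersion, // 1
	Aliases: []aliasEntry{
		{Alias: "my-model", Target: "library/llama3.2:latest"},
		{Alias: "gpt-", Target: "library/llama3.2:latest", PrefixMatching: true},
	},
}
// json.Marshal(cfg) would produce:
// {"version":1,"aliases":[{"alias":"my-model","target":"library/llama3.2:latest"},
//  {"alias":"gpt-","target":"library/llama3.2:latest","prefix_matching":true}]}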
type store struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
|
||||
prefixEntries []aliasEntry // prefix matches, sorted longest-first
|
||||
}
|
||||
|
||||
func createStore(path string) (*store, error) {
|
||||
store := &store{
|
||||
path: path,
|
||||
entries: make(map[string]aliasEntry),
|
||||
}
|
||||
if err := store.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *store) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg serverConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
|
||||
return fmt.Errorf("unsupported router config version %d", cfg.Version)
|
||||
}
|
||||
|
||||
for _, entry := range cfg.Aliases {
|
||||
targetName := model.ParseName(entry.Target)
|
||||
if !targetName.IsValid() {
|
||||
slog.Warn("invalid alias target in router config", "target", entry.Target)
|
||||
continue
|
||||
}
|
||||
canonicalTarget := displayAliasName(targetName)
|
||||
|
||||
if entry.PrefixMatching {
|
||||
// Prefix aliases don't need to be valid model names
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
slog.Warn("empty prefix alias in router config")
|
||||
continue
|
||||
}
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: alias,
|
||||
Target: canonicalTarget,
|
||||
PrefixMatching: true,
|
||||
})
|
||||
} else {
|
||||
aliasName := model.ParseName(entry.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
|
||||
continue
|
||||
}
|
||||
canonicalAlias := displayAliasName(aliasName)
|
||||
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
|
||||
Alias: canonicalAlias,
|
||||
Target: canonicalTarget,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort prefix entries by alias length descending (longest prefix wins)
|
||||
s.sortPrefixEntriesLocked()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *store) saveLocked() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine exact and prefix entries
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
|
||||
cfg := serverConfig{
|
||||
Version: serverConfigVersion,
|
||||
Aliases: entries,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(dir, "router-*.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Chmod(f.Name(), 0o644); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Rename(f.Name(), s.path)
|
||||
}
|
||||
|
||||
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
|
||||
// If a local model exists, do not allow alias shadowing (highest priority).
|
||||
exists, err := localModelExists(name)
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
if exists {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
key := normalizeAliasKey(name)
|
||||
|
||||
s.mu.RLock()
|
||||
entry, exactMatch := s.entries[key]
|
||||
var prefixMatch *aliasEntry
|
||||
if !exactMatch {
|
||||
// Try prefix matching - prefixEntries is sorted longest-first
|
||||
nameStr := strings.ToLower(displayAliasName(name))
|
||||
for i := range s.prefixEntries {
|
||||
prefix := strings.ToLower(s.prefixEntries[i].Alias)
|
||||
if strings.HasPrefix(nameStr, prefix) {
|
||||
prefixMatch = &s.prefixEntries[i]
|
||||
break // First match is longest due to sorting
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exactMatch && prefixMatch == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
var current string
|
||||
var visited map[string]struct{}
|
||||
|
||||
if exactMatch {
|
||||
visited = map[string]struct{}{key: {}}
|
||||
current = entry.Target
|
||||
} else {
|
||||
// For prefix match, use the target as-is
|
||||
visited = map[string]struct{}{}
|
||||
current = prefixMatch.Target
|
||||
}
|
||||
|
||||
targetKey := normalizeAliasKeyString(current)
|
||||
|
||||
for {
|
||||
targetName := model.ParseName(current)
|
||||
if !targetName.IsValid() {
|
||||
return name, false, fmt.Errorf("alias target %q is invalid", current)
|
||||
}
|
||||
|
||||
if _, seen := visited[targetKey]; seen {
|
||||
return name, false, errAliasCycle
|
||||
}
|
||||
visited[targetKey] = struct{}{}
|
||||
|
||||
s.mu.RLock()
|
||||
next, ok := s.entries[targetKey]
|
||||
s.mu.RUnlock()
|
||||
if !ok {
|
||||
return targetName, true, nil
|
||||
}
|
||||
|
||||
current = next.Target
|
||||
targetKey = normalizeAliasKeyString(current)
|
||||
}
|
||||
}

func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
	targetKey := normalizeAliasKey(target)

	s.mu.Lock()
	defer s.mu.Unlock()

	if prefixMatching {
		// For prefix aliases, we skip cycle detection since prefix matching
		// works differently and the target is a specific model
		aliasStr := displayAliasName(alias)

		// Remove any existing prefix entry with the same alias
		for i, e := range s.prefixEntries {
			if strings.EqualFold(e.Alias, aliasStr) {
				s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
				break
			}
		}

		s.prefixEntries = append(s.prefixEntries, aliasEntry{
			Alias:          aliasStr,
			Target:         displayAliasName(target),
			PrefixMatching: true,
		})
		s.sortPrefixEntriesLocked()
		return s.saveLocked()
	}

	aliasKey := normalizeAliasKey(alias)

	if aliasKey == targetKey {
		return fmt.Errorf("alias cannot point to itself")
	}

	visited := map[string]struct{}{aliasKey: {}}
	currentKey := targetKey
	for {
		if _, seen := visited[currentKey]; seen {
			return errAliasCycle
		}
		visited[currentKey] = struct{}{}

		next, ok := s.entries[currentKey]
		if !ok {
			break
		}
		currentKey = normalizeAliasKeyString(next.Target)
	}

	s.entries[aliasKey] = aliasEntry{
		Alias:  displayAliasName(alias),
		Target: displayAliasName(target),
	}

	return s.saveLocked()
}
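
// Cycle detection sketch (hypothetical names): the second Set below would
// close the loop a -> b -> a, so it is rejected with errAliasCycle.
//
//	_ = s.Set(model.ParseName("a"), model.ParseName("b"), false)
//	err := s.Set(model.ParseName("b"), model.ParseName("a"), false)
//	// errors.Is(err, errAliasCycle) == true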

func (s *store) Delete(alias model.Name) (bool, error) {
	aliasKey := normalizeAliasKey(alias)

	s.mu.Lock()
	defer s.mu.Unlock()

	// Try exact match first
	if _, ok := s.entries[aliasKey]; ok {
		delete(s.entries, aliasKey)
		return true, s.saveLocked()
	}

	// Try prefix entries
	aliasStr := displayAliasName(alias)
	for i, e := range s.prefixEntries {
		if strings.EqualFold(e.Alias, aliasStr) {
			s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
			return true, s.saveLocked()
		}
	}

	return false, nil
}

// DeleteByString deletes an alias by its raw string value, useful for prefix
// aliases that may not be valid model names.
func (s *store) DeleteByString(alias string) (bool, error) {
	alias = strings.TrimSpace(alias)
	aliasLower := strings.ToLower(alias)

	s.mu.Lock()
	defer s.mu.Unlock()

	// Try prefix entries first (since this is mainly for prefix aliases)
	for i, e := range s.prefixEntries {
		if strings.EqualFold(e.Alias, alias) {
			s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
			return true, s.saveLocked()
		}
	}

	// Also check exact entries by normalized key
	if _, ok := s.entries[aliasLower]; ok {
		delete(s.entries, aliasLower)
		return true, s.saveLocked()
	}

	return false, nil
}

func (s *store) List() []aliasEntry {
	s.mu.RLock()
	defer s.mu.RUnlock()

	entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
	for _, entry := range s.entries {
		entries = append(entries, entry)
	}
	entries = append(entries, s.prefixEntries...)

	sort.Slice(entries, func(i, j int) bool {
		return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
	})
	return entries
}

func normalizeAliasKey(name model.Name) string {
	return strings.ToLower(displayAliasName(name))
}

func (s *store) sortPrefixEntriesLocked() {
	sort.Slice(s.prefixEntries, func(i, j int) bool {
		// Sort by length descending (longest prefix first)
		return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
	})
}
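
// Ordering sketch: with "abc-" and "abc-def-" registered, sorting
// longest-first means a lookup for "abc-def-ghi" matches "abc-def-"
// before "abc-" is ever consulted, which ResolveName relies on.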

func normalizeAliasKeyString(value string) string {
	n := model.ParseName(value)
	if !n.IsValid() {
		return strings.ToLower(strings.TrimSpace(value))
	}
	return normalizeAliasKey(n)
}

func displayAliasName(n model.Name) string {
	display := n.DisplayShortest()
	if strings.EqualFold(n.Tag, "latest") {
		if idx := strings.LastIndex(display, ":"); idx != -1 {
			return display[:idx]
		}
	}
	return display
}
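
// Display sketch: a name tagged "latest" drops the tag, so "mymodel:latest"
// and "mymodel" both normalize to the alias key "mymodel", while an explicit
// tag such as "mymodel:8b" is kept as-is.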

func localModelExists(name model.Name) (bool, error) {
	manifests, err := manifest.Manifests(true)
	if err != nil {
		return false, err
	}
	needle := name.String()
	for existing := range manifests {
		if strings.EqualFold(existing.String(), needle) {
			return true, nil
		}
	}
	return false, nil
}

func serverConfigPath() string {
	home, err := os.UserHomeDir()
	if err != nil {
		return filepath.Join(".ollama", serverConfigFilename)
	}
	return filepath.Join(home, ".ollama", serverConfigFilename)
}

func (s *Server) aliasStore() (*store, error) {
	s.aliasesOnce.Do(func() {
		s.aliases, s.aliasesErr = createStore(serverConfigPath())
	})

	return s.aliases, s.aliasesErr
}

func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
	store, err := s.aliasStore()
	if err != nil {
		return name, false, err
	}

	if store == nil {
		return name, false, nil
	}

	return store.ResolveName(name)
}
@@ -22,7 +22,6 @@ import (
	"os/signal"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"
@@ -52,7 +51,7 @@ import (
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
	imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen"
	xserver "github.com/ollama/ollama/x/server"
)

@@ -82,9 +81,6 @@ type Server struct {
	addr          net.Addr
	sched         *Scheduler
	defaultNumCtx int
	aliasesOnce   sync.Once
	aliases       *store
	aliasesErr    error
}

func init() {
@@ -195,16 +191,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

	resolvedName, _, err := s.resolveAlias(name)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	name = resolvedName

	// We cannot currently consolidate this into GetModel because all we'll do is
	// induce infinite recursion given the current code structure.
	name, err = getExistingName(name)
	name, err := getExistingName(name)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
@@ -1106,7 +1095,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {

	// For image generation models, populate details from imagegen package
	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
		if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
		if info, err := imagegen.GetModelInfo(name.String()); err == nil {
			modelDetails.Family = info.Architecture
			modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
			modelDetails.QuantizationLevel = info.Quantization
@@ -1591,9 +1580,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
	r.POST("/api/copy", s.CopyHandler)
	r.GET("/api/experimental/aliases", s.ListAliasesHandler)
	r.POST("/api/experimental/aliases", s.CreateAliasHandler)
	r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)

	// Inference
	r.GET("/api/ps", s.PsHandler)
@@ -1964,20 +1950,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
		return
	}

	resolvedName, _, err := s.resolveAlias(name)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	name = resolvedName

	name, err = getExistingName(name)
	name, err := getExistingName(name)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	m, err := GetModel(name.String())
	m, err := GetModel(req.Model)
	if err != nil {
		switch {
		case os.IsNotExist(err):

@@ -1,159 +0,0 @@
package server

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/types/model"
)

type aliasListResponse struct {
	Aliases []aliasEntry `json:"aliases"`
}

type aliasDeleteRequest struct {
	Alias string `json:"alias"`
}

func (s *Server) ListAliasesHandler(c *gin.Context) {
	store, err := s.aliasStore()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	var aliases []aliasEntry
	if store != nil {
		aliases = store.List()
	}

	c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
}

func (s *Server) CreateAliasHandler(c *gin.Context) {
	var req aliasEntry
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	req.Alias = strings.TrimSpace(req.Alias)
	req.Target = strings.TrimSpace(req.Target)
	if req.Alias == "" || req.Target == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
		return
	}

	// Target must always be a valid model name
	targetName := model.ParseName(req.Target)
	if !targetName.IsValid() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
		return
	}

	var aliasName model.Name
	if req.PrefixMatching {
		// For prefix aliases, we still parse the alias to normalize it,
		// but we allow any non-empty string since prefix patterns may not be valid model names
		aliasName = model.ParseName(req.Alias)
		// Even if not valid as a model name, we accept it for prefix matching
	} else {
		aliasName = model.ParseName(req.Alias)
		if !aliasName.IsValid() {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
			return
		}

		if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
			return
		}

		exists, err := localModelExists(aliasName)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		if exists {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
			return
		}
	}

	store, err := s.aliasStore()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
		status := http.StatusInternalServerError
		if errors.Is(err, errAliasCycle) {
			status = http.StatusBadRequest
		}
		c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
		return
	}

	resp := aliasEntry{
		Alias:          displayAliasName(aliasName),
		Target:         displayAliasName(targetName),
		PrefixMatching: req.PrefixMatching,
	}
	if req.PrefixMatching && !aliasName.IsValid() {
		// For prefix aliases that aren't valid model names, use the raw alias
		resp.Alias = req.Alias
	}
	c.JSON(http.StatusOK, resp)
}
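
// Usage sketch for the experimental alias endpoints above. The JSON field
// names are assumed from aliasEntry's struct tags (alias/target/
// prefix_matching); the model names are hypothetical.
//
//	POST   /api/experimental/aliases  {"alias": "fast", "target": "llama3.2:1b"}
//	POST   /api/experimental/aliases  {"alias": "team-", "target": "llama3.2:1b", "prefix_matching": true}
//	GET    /api/experimental/aliases  -> {"aliases": [...]}
//	DELETE /api/experimental/aliases  {"alias": "fast"}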

func (s *Server) DeleteAliasHandler(c *gin.Context) {
	var req aliasDeleteRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	req.Alias = strings.TrimSpace(req.Alias)
	if req.Alias == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
		return
	}

	store, err := s.aliasStore()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	aliasName := model.ParseName(req.Alias)
	var deleted bool
	if aliasName.IsValid() {
		deleted, err = store.Delete(aliasName)
	} else {
		// For invalid model names (like prefix aliases), try deleting by raw string
		deleted, err = store.DeleteByString(req.Alias)
	}

	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if !deleted {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
		return
	}

	c.JSON(http.StatusOK, gin.H{"deleted": true})
}
@@ -1,426 +0,0 @@
package server

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/url"
	"path/filepath"
	"testing"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

func TestAliasShadowingRejected(t *testing.T) {
	gin.SetMode(gin.TestMode)
	t.Setenv("HOME", t.TempDir())

	s := Server{}
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:      "shadowed-model",
		RemoteHost: "example.com",
		From:       "test",
		Info: map[string]any{
			"capabilities": []string{"completion"},
		},
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", w.Code)
	}
}

func TestAliasResolvesForChatRemote(t *testing.T) {
	gin.SetMode(gin.TestMode)
	t.Setenv("HOME", t.TempDir())

	var remoteModel string
	rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req api.ChatRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Fatal(err)
		}
		remoteModel = req.Model

		w.Header().Set("Content-Type", "application/json")
		resp := api.ChatResponse{
			Model:      req.Model,
			Done:       true,
			DoneReason: "load",
		}
		if err := json.NewEncoder(w).Encode(&resp); err != nil {
			t.Fatal(err)
		}
	}))
	defer rs.Close()

	p, err := url.Parse(rs.URL)
	if err != nil {
		t.Fatal(err)
	}

	t.Setenv("OLLAMA_REMOTES", p.Hostname())

	s := Server{}
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:      "target-model",
		RemoteHost: rs.URL,
		From:       "test",
		Info: map[string]any{
			"capabilities": []string{"completion"},
		},
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	w = createRequest(t, s.ChatHandler, api.ChatRequest{
		Model:    "alias-model",
		Messages: []api.Message{{Role: "user", Content: "hi"}},
		Stream:   &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	var resp api.ChatResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if resp.Model != "alias-model" {
		t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
	}

	if remoteModel != "test" {
		t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
	}
}

func TestPrefixAliasBasicMatching(t *testing.T) {
	tmpDir := t.TempDir()
	store, err := createStore(filepath.Join(tmpDir, "server.json"))
	if err != nil {
		t.Fatal(err)
	}

	// Create a prefix alias: "myprefix-" -> "targetmodel"
	targetName := model.ParseName("targetmodel")

	// Set a prefix alias (using "myprefix-" as the pattern)
	store.mu.Lock()
	store.prefixEntries = append(store.prefixEntries, aliasEntry{
		Alias:          "myprefix-",
		Target:         "targetmodel",
		PrefixMatching: true,
	})
	store.mu.Unlock()

	// Test that "myprefix-foo" resolves to "targetmodel"
	testName := model.ParseName("myprefix-foo")
	resolved, wasResolved, err := store.ResolveName(testName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved")
	}
	if resolved.DisplayShortest() != targetName.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
	}

	// Test that "otherprefix-foo" does not resolve
	otherName := model.ParseName("otherprefix-foo")
	_, wasResolved, err = store.ResolveName(otherName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if wasResolved {
		t.Fatal("expected name not to be resolved")
	}

	// Test that exact alias takes precedence
	exactAlias := model.ParseName("myprefix-exact")
	exactTarget := model.ParseName("exacttarget")
	if err := store.Set(exactAlias, exactTarget, false); err != nil {
		t.Fatal(err)
	}

	resolved, wasResolved, err = store.ResolveName(exactAlias)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved")
	}
	if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
	}
}

func TestPrefixAliasLongestMatchWins(t *testing.T) {
	tmpDir := t.TempDir()
	store, err := createStore(filepath.Join(tmpDir, "server.json"))
	if err != nil {
		t.Fatal(err)
	}

	// Add two prefix aliases with overlapping patterns
	store.mu.Lock()
	store.prefixEntries = []aliasEntry{
		{Alias: "abc-", Target: "short-target", PrefixMatching: true},
		{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
	}
	store.sortPrefixEntriesLocked()
	store.mu.Unlock()

	// "abc-def-ghi" should match the longer prefix "abc-def-"
	testName := model.ParseName("abc-def-ghi")
	resolved, wasResolved, err := store.ResolveName(testName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved")
	}
	expectedLongTarget := model.ParseName("long-target")
	if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
	}

	// "abc-xyz" should match the shorter prefix "abc-"
	testName2 := model.ParseName("abc-xyz")
	resolved, wasResolved, err = store.ResolveName(testName2)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved")
	}
	expectedShortTarget := model.ParseName("short-target")
	if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
	}
}

func TestPrefixAliasChain(t *testing.T) {
	tmpDir := t.TempDir()
	store, err := createStore(filepath.Join(tmpDir, "server.json"))
	if err != nil {
		t.Fatal(err)
	}

	// Create a chain: prefix "test-" -> "intermediate" -> "final"
	intermediate := model.ParseName("intermediate")
	final := model.ParseName("final")

	// Add prefix alias
	store.mu.Lock()
	store.prefixEntries = []aliasEntry{
		{Alias: "test-", Target: "intermediate", PrefixMatching: true},
	}
	store.mu.Unlock()

	// Add exact alias for the intermediate step
	if err := store.Set(intermediate, final, false); err != nil {
		t.Fatal(err)
	}

	// "test-foo" should resolve through the chain to "final"
	testName := model.ParseName("test-foo")
	resolved, wasResolved, err := store.ResolveName(testName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved")
	}
	if resolved.DisplayShortest() != final.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
	}
}

func TestPrefixAliasCRUD(t *testing.T) {
	gin.SetMode(gin.TestMode)
	t.Setenv("HOME", t.TempDir())

	s := Server{}

	// Create a prefix alias via API
	w := createRequest(t, s.CreateAliasHandler, aliasEntry{
		Alias:          "myprefix-",
		Target:         "llama2",
		PrefixMatching: true,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	var createResp aliasEntry
	if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
		t.Fatal(err)
	}
	if !createResp.PrefixMatching {
		t.Fatal("expected prefix_matching to be true in response")
	}

	// List aliases and verify the prefix alias is included
	w = createRequest(t, s.ListAliasesHandler, nil)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	var listResp aliasListResponse
	if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
		t.Fatal(err)
	}

	found := false
	for _, a := range listResp.Aliases {
		if a.PrefixMatching && a.Target == "llama2" {
			found = true
			break
		}
	}
	if !found {
		t.Fatal("expected to find prefix alias in list")
	}

	// Delete the prefix alias
	w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	// Verify it's deleted
	w = createRequest(t, s.ListAliasesHandler, nil)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
		t.Fatal(err)
	}

	for _, a := range listResp.Aliases {
		if a.PrefixMatching {
			t.Fatal("expected prefix alias to be deleted")
		}
	}
}

func TestPrefixAliasCaseInsensitive(t *testing.T) {
	tmpDir := t.TempDir()
	store, err := createStore(filepath.Join(tmpDir, "server.json"))
	if err != nil {
		t.Fatal(err)
	}

	// Add a prefix alias with mixed case
	store.mu.Lock()
	store.prefixEntries = []aliasEntry{
		{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
	}
	store.mu.Unlock()

	// Test that matching is case-insensitive
	testName := model.ParseName("myprefix-foo")
	resolved, wasResolved, err := store.ResolveName(testName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved (case-insensitive)")
	}
	expectedTarget := model.ParseName("targetmodel")
	if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
	}

	// Test uppercase request
	testName2 := model.ParseName("MYPREFIX-BAR")
	_, wasResolved, err = store.ResolveName(testName2)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected name to be resolved (uppercase)")
	}
}

func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
	gin.SetMode(gin.TestMode)
	t.Setenv("HOME", t.TempDir())

	s := Server{}

	// Create a local model that would match a prefix alias
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:      "myprefix-localmodel",
		RemoteHost: "example.com",
		From:       "test",
		Info: map[string]any{
			"capabilities": []string{"completion"},
		},
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	// Create a prefix alias that would match the local model name
	w = createRequest(t, s.CreateAliasHandler, aliasEntry{
		Alias:          "myprefix-",
		Target:         "someothermodel",
		PrefixMatching: true,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
	store, err := s.aliasStore()
	if err != nil {
		t.Fatal(err)
	}

	localModelName := model.ParseName("myprefix-localmodel")
	resolved, wasResolved, err := store.ResolveName(localModelName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if wasResolved {
		t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
	}
	if resolved.DisplayShortest() != localModelName.DisplayShortest() {
		t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
	}

	// Also verify that a non-local model matching the prefix DOES resolve to the alias target
	nonLocalName := model.ParseName("myprefix-nonexistent")
	resolved, wasResolved, err = store.ResolveName(nonLocalName)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !wasResolved {
		t.Fatal("expected non-local model to resolve via prefix alias")
	}
	expectedTarget := model.ParseName("someothermodel")
	if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
		t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
	}
}
119
server/sched.go
@@ -5,13 +5,9 @@ import (
	"errors"
	"fmt"
	"log/slog"
	"math/rand"
	"os"
	"os/exec"
	"reflect"
	"slices"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"
@@ -25,7 +21,6 @@ import (
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/mlxrunner"
)

@@ -200,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
			slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
		}

		// Check for experimental safetensors LLM models
		if pending.model.Config.ModelFormat == "safetensors" {
			if s.loadSafetensors(pending) {
		// Check for image generation models - all use MLX runner
		if slices.Contains(pending.model.Config.Capabilities, "image") {
			if s.loadMLX(pending) {
				break
			}
			continue
		}

		// Check for experimental safetensors LLM models
		if pending.model.Config.ModelFormat == "safetensors" {
			if slices.Contains(pending.model.Config.Capabilities, "completion") {
				// LLM model with safetensors format - use MLX runner
				if s.loadMLX(pending) {
					break
				}
				continue
			}
		}

		// Load model for fitting
		logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
		ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -411,9 +417,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
		numParallel = 1
	}

	// Some architectures are not safe with num_parallel > 1.
	// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and use an encoder cache which cannot be used with num_parallel > 1
	// ref: https://github.com/ollama/ollama/issues/4165
	if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
	if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
		numParallel = 1
		slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
	}
@@ -557,101 +563,20 @@ iGPUScan:
	return false
}

func subproc(args, environ []string) (*exec.Cmd, int, error) {
	exe, err := os.Executable()
	if err != nil {
		return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
	}

	for range 3 {
		// get a random port in the ephemeral range
		port := rand.Intn(65535-49152) + 49152
		cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
		cmd.Env = slices.Concat(os.Environ(), environ)
		cmd.Stdout = os.Stderr
		cmd.Stderr = os.Stderr
		if err := cmd.Start(); err != nil {
			continue
		}

		return cmd, port, nil
	}

	return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
}
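
// Invocation sketch: with args from a caller such as loadSafetensors below
// and a hypothetical model "foo", the command assembled above amounts to
//
//	/path/to/ollama runner --mlx-engine --model foo --port 51234
//
// and the loop retries with a fresh random ephemeral port (49152-65535)
// if the subprocess fails to start.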

func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
	if slices.Contains(req.model.Config.Capabilities, "image") {
		return s.loadImageGen(req)
	}

	args := []string{"--mlx-engine", "--model", req.model.ShortName}
	environ := []string{}
	cmd, port, err := subproc(args, environ)
	if err != nil {
		req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
		return true
	}

	sessionDuration := envconfig.KeepAlive()
	if req.sessionDuration != nil {
		sessionDuration = req.sessionDuration.Duration
	}

	runner := &runnerRef{
		model:           req.model,
		modelPath:       req.model.ModelPath,
		Options:         &req.opts,
		loading:         false,
		sessionDuration: sessionDuration,
		llama: &mlxrunner.Client{
			Cmd:  cmd,
			Port: port,
		},
	}

	s.loadedMu.Lock()
	s.loaded[req.model.ModelPath] = runner
	s.loadedMu.Unlock()

	runner.refMu.Lock()
	if sessionDuration > 0 {
		runner.expireTimer = time.AfterFunc(sessionDuration, func() {
			s.expiredCh <- runner
		})
	}
	runner.refMu.Unlock()
	req.useLoadedRunner(runner, s.finishedReqCh)

	for range time.Tick(20 * time.Millisecond) {
		if err := func() error {
			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
			defer cancel()
			return runner.llama.Ping(ctx)
		}(); err != nil {
			continue
		}

		break
	}

	return true
}

// loadImageGen loads an experimental safetensors model using the unified MLX runner.
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
	// Determine mode based on capabilities
	var mode imagegen.ModelMode
	var mode mlxrunner.ModelMode
	if slices.Contains(req.model.Config.Capabilities, "image") {
		mode = imagegen.ModeImageGen
		mode = mlxrunner.ModeImageGen
	} else {
		mode = imagegen.ModeLLM
		mode = mlxrunner.ModeLLM
	}

	// Use model name for MLX (it resolves manifests by name, not file path)
	modelName := req.model.ShortName
	server, err := imagegen.NewServer(modelName, mode)
	server, err := mlxrunner.NewServer(modelName, mode)
	if err != nil {
		req.errCh <- err
		return true

@@ -1,310 +0,0 @@
package tokenizer

import (
	"encoding/json"
	"errors"
	"io"
	"os"

	"github.com/ollama/ollama/types/model"
)

const (
	TOKEN_TYPE_NORMAL = iota + 1
	TOKEN_TYPE_UNKNOWN
	TOKEN_TYPE_CONTROL
	TOKEN_TYPE_USER_DEFINED
	TOKEN_TYPE_UNUSED
	TOKEN_TYPE_BYTE
)

type Tokenizer interface {
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode([]int32) (string, error)
	Is(int32, Special) bool
	Vocabulary() *Vocabulary
}

func New(root *model.Root) (Tokenizer, error) {
	f, err := root.Open("tokenizer.json")
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var tokenizer struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges json.RawMessage  `json:"merges"`
		} `json:"model"`

		PreTokenizer json.RawMessage `json:"pre_tokenizer"`
		Decoder      json.RawMessage `json:"decoder"`

		AddedTokens []struct {
			ID      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
	}

	if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
		return nil, err
	}

	special := make(map[int32]struct{})
	for _, token := range tokenizer.AddedTokens {
		tokenizer.Model.Vocab[token.Content] = token.ID
		special[token.ID] = struct{}{}
	}

	vocab, err := specialTokens(root, tokenizer.Model.Vocab)
	if err != nil {
		return nil, err
	}

	vocab.Values = make([]string, len(tokenizer.Model.Vocab))
	vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
	vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
	for content, id := range tokenizer.Model.Vocab {
		vocab.Values[id] = content
		vocab.Scores[id] = float32(id)
		vocab.Types[id] = TOKEN_TYPE_NORMAL
		if _, ok := special[id]; ok {
			vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
		}
	}

	if tokenizer.Model.Merges != nil {
		var pairs [][]string
		if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
			vocab.Merges = make([]string, len(pairs))
			for i, pair := range pairs {
				vocab.Merges[i] = pair[0] + " " + pair[1]
			}
		} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
			return nil, err
		}
	}

	vocab.valuesOnce.Do(func() {})
	vocab.values = tokenizer.Model.Vocab

	if tokenizer.Model.Type == "WordPiece" {
		return NewWordPiece(vocab, true), nil
	}

	if tokenizer.Decoder != nil {
		var decoder struct {
			Type     string `json:"type"`
			Decoders []struct {
				Type    string `json:"type"`
				Pattern struct {
					String string `json:"string"`
				} `json:"pattern"`
			} `json:"decoders"`
		}

		if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
			return nil, err
		}

		if decoder.Type == "Sequence" {
			for _, d := range decoder.Decoders {
				if d.Type == "Replace" && d.Pattern.String == "▁" {
					return NewSentencePiece(vocab), nil
				}
			}
		}
	}

	var pretokenizers []string
	if tokenizer.PreTokenizer != nil {
		var pretokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					Regex string
				} `json:"pattern"`
				IndividualDigits bool `json:"individual_digits"`
			}
		}

		if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
			return nil, err
		}

		if pretokenizer.Type == "Sequence" {
			for _, pretokenizer := range pretokenizer.Pretokenizers {
				switch pretokenizer.Type {
				case "Digits":
					if pretokenizer.IndividualDigits {
						pretokenizers = append(pretokenizers, `\d`)
					} else {
						pretokenizers = append(pretokenizers, `\d+`)
					}
				case "Punctuation":
					pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
				case "Split":
					pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
				case "WhitespaceSplit":
					pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
				}
			}
		}
	}

	return NewBytePairEncoding(vocab, pretokenizers...), nil
}

// valueOrValues is a type that can unmarshal from either a single value or an array of values.
type valueOrValues[E any] []E

func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
	var s []E
	if err := json.Unmarshal(data, &s); err != nil {
		var e E
		if err := json.Unmarshal(data, &e); err != nil {
			return err
		}
		s = []E{e}
	}
	*m = valueOrValues[E](s)
	return nil
}
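
// JSON shape sketch: valueOrValues accepts either form below, so a config
// field like "eos_token_id" can be a single ID or a list of IDs.
//
//	{"eos_token_id": 2}      -> valueOrValues[int32]{2}
//	{"eos_token_id": [2, 3]} -> valueOrValues[int32]{2, 3}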

type specialTokenIDs struct {
	BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
	EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}

// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
type stringOrContent string

func (t *stringOrContent) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		var m map[string]any
		if err := json.Unmarshal(data, &m); err != nil {
			return err
		}
		if content, ok := m["content"].(string); ok {
			s = content
		}
	}
	*t = stringOrContent(s)
	return nil
}
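
// Similarly, stringOrContent tolerates both tokenizer_config styles:
//
//	{"bos_token": "<s>"}                               -> "<s>"
//	{"bos_token": {"content": "<s>", "lstrip": false}} -> "<s>"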

func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
	var vocab Vocabulary
	for _, c := range []struct {
		name string
		fn   func(io.Reader) error
	}{
		{
			name: "generation_config.json",
			fn: func(r io.Reader) error {
				var c specialTokenIDs
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				vocab.BOS = c.BOSTokenID
				vocab.EOS = c.EOSTokenID
				return nil
			},
		},
		{
			name: "config.json",
			fn: func(r io.Reader) error {
				var c specialTokenIDs
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				if len(vocab.BOS) == 0 {
					vocab.BOS = c.BOSTokenID
				}

				if len(vocab.EOS) == 0 {
					vocab.EOS = c.EOSTokenID
				}

				return nil
			},
		},
		{
			name: "tokenizer_config.json",
			fn: func(r io.Reader) error {
				var c struct {
					BOSToken    stringOrContent `json:"bos_token"`
					EOSToken    stringOrContent `json:"eos_token"`
					PADToken    stringOrContent `json:"pad_token"`
					AddBOSToken bool            `json:"add_bos_token"`
					AddEOSToken bool            `json:"add_eos_token"`
				}
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				if len(vocab.BOS) == 0 && c.BOSToken != "" {
					if id, ok := values[string(c.BOSToken)]; ok {
						vocab.BOS = []int32{id}
					}
				}

				if len(vocab.EOS) == 0 && c.EOSToken != "" {
					if id, ok := values[string(c.EOSToken)]; ok {
						vocab.EOS = []int32{id}
					}
				}

				vocab.AddBOS = c.AddBOSToken
				vocab.AddEOS = c.AddEOSToken
				return nil
			},
		},
		{
			name: "special_tokens_map.json",
			fn: func(r io.Reader) error {
				var c map[string]stringOrContent
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
					if id, ok := values[string(bos)]; ok {
						vocab.BOS = []int32{id}
					}
				}

				if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
					if id, ok := values[string(eos)]; ok {
						vocab.EOS = []int32{id}
					}
				}

				return nil
			},
		},
	} {
		if err := func() error {
			f, err := root.Open(c.name)
			if errors.Is(err, os.ErrNotExist) {
				return nil
			} else if err != nil {
				return err
			}
			defer f.Close()

			return c.fn(f)
		}(); err != nil {
			return nil, err
		}
	}

	return &vocab, nil
}
@@ -1,309 +0,0 @@
package model

import (
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"hash"
	"io"
	"io/fs"
	"iter"
	"maps"
	"mime"
	"os"
	"path/filepath"
	"strings"

	"github.com/ollama/ollama/envconfig"
)

func root() (*os.Root, error) {
	root, err := os.OpenRoot(envconfig.Models())
	if err != nil {
		return nil, err
	}

	for _, sub := range []string{"manifests", "blobs"} {
		if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
			if err := root.MkdirAll(sub, 0o750); err != nil {
				return nil, err
			}
		} else if err != nil {
			return nil, err
		}
	}

	return root, nil
}

// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
	r, err := root()
	if err != nil {
		return nil, err
	}

	f, err := r.Open(filepath.Join("manifests", n.Filepath()))
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var m manifest
	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}

	blobs := make(map[string]*blob, len(m.Layers)+1)
	blobs[NamePrefix] = m.Config
	for _, layer := range m.Layers {
		if layer.Name == "" && layer.MediaType != "" {
			mediatype, _, err := mime.ParseMediaType(layer.MediaType)
			if err != nil {
				return nil, err
			}

			if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
				layer.Name = NamePrefix + suffix
			}
		}

		blobs[layer.Name] = layer
	}

	return &Root{
		root:  r,
		name:  n,
		blobs: blobs,
		flags: os.O_RDONLY,
	}, nil
}

// Create creates a new file. The returned [Root] can be used for both reading
// and writing. It is the caller's responsibility to close the file when done
// in order to finalize any new blobs and write the manifest.
func Create(n Name) (*Root, error) {
	r, err := root()
	if err != nil {
		return nil, err
	}

	return &Root{
		root:  r,
		name:  n,
		blobs: make(map[string]*blob),
		flags: os.O_RDWR,
	}, nil
}

type blob struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Name      string `json:"name,omitempty"`
	Size      int64  `json:"size"`

	// tempfile is the temporary file where the blob data is written.
	tempfile *os.File

	// hash is the hash.Hash used to compute the blob digest.
	hash hash.Hash
}

func (b *blob) Write(p []byte) (int, error) {
	return io.MultiWriter(b.tempfile, b.hash).Write(p)
}

func (b *blob) Filepath() string {
	return strings.ReplaceAll(b.Digest, ":", "-")
}

type manifest struct {
	SchemaVersion int     `json:"schemaVersion"`
	MediaType     string  `json:"mediaType"`
	Config        *blob   `json:"config"`
	Layers        []*blob `json:"layers"`
}

// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
	root  *os.Root
	name  Name
	blobs map[string]*blob
	flags int
}

const MediaTypePrefix = "application/vnd.ollama"

// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are identified by their media types:
//
//   - name: NamePrefix + suffix
//   - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
//   - name: "./..image.model"
//   - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."

// Open opens the named blob for reading. It is the caller's responsibility
// to close the returned [io.ReadCloser] when done. It will return
// [fs.ErrNotExist] if the blob does not exist.
func (r Root) Open(name string) (io.ReadCloser, error) {
	if b, ok := r.blobs[name]; ok {
		r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
		if err != nil {
			return nil, err
		}
		return r, nil
	}

	return nil, fs.ErrNotExist
}

func (r Root) ReadFile(name string) ([]byte, error) {
	f, err := r.Open(name)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	return io.ReadAll(f)
}

// Create creates or replaces a named blob in the file. If the blob already
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
// was opened in read-only mode. The returned [io.Writer] can be used to write
// to the blob and does not need to be closed, but the file must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
	if r.flags&os.O_RDWR != 0 {
		w, err := os.CreateTemp(r.root.Name(), "")
		if err != nil {
			return nil, err
		}

		r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
		return r.blobs[name], nil
	}

	return nil, fs.ErrInvalid
}

// Close closes the file. If the file was opened in read-write mode, it
// will finalize any writeable blobs and write the manifest.
func (r *Root) Close() error {
	if r.flags&os.O_RDWR != 0 {
		for _, b := range r.blobs {
			if b.tempfile != nil {
				fi, err := b.tempfile.Stat()
				if err != nil {
					return err
				}

				if err := b.tempfile.Close(); err != nil {
					return err
				}

				b.Size = fi.Size()
				b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))

				if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
					if b.Name == NamePrefix {
						b.MediaType = "application/vnd.docker.container.image.v1+json"
					} else {
						b.MediaType = MediaTypePrefix + suffix
					}
					b.Name = ""
				}

				rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
				if err != nil {
					return err
				}

				if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
					return err
				}
			}
		}

		p := filepath.Join("manifests", r.name.Filepath())
		if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
			if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
				return err
			}
		} else if err != nil {
			return err
		}

		f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
		if err != nil {
			return err
		}
		defer f.Close()

		if err := json.NewEncoder(f).Encode(manifest{
			SchemaVersion: 2,
			MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
			Config:        r.blobs[NamePrefix],
			Layers: func() []*blob {
				blobs := make([]*blob, 0, len(r.blobs))
				for name, b := range r.blobs {
					if name != NamePrefix {
						blobs = append(blobs, b)
					}
				}
				return blobs
			}(),
		}); err != nil {
			return err
		}
	}

	return r.root.Close()
}

// Name returns the name of the file.
func (r Root) Name() Name {
	return r.name
}

// Names returns an iterator over the names in the file.
func (r Root) Names() iter.Seq[string] {
	return maps.Keys(r.blobs)
}

// Glob returns an iterator over the names in the file that match the given
// pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
// the only possible returned error is ErrBadPattern, when pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
	if _, err := filepath.Match(pattern, ""); err != nil {
		return nil, err
	}

	return func(yield func(string) bool) {
		for name, blob := range r.blobs {
			if matched, _ := filepath.Match(pattern, name); matched {
				if !yield(blob.Filepath()) {
					return
				}
			}
		}
	}, nil
}
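
// Usage sketch (hypothetical pattern): Glob yields the digest-derived blob
// filenames for matching names, consumable with range-over-func:
//
//	names, err := r.Glob("tokenizer*")
//	if err == nil {
//		for p := range names {
//			fmt.Println(p) // e.g. "sha256-<digest>"
//		}
//	}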

func (r Root) JoinPath(parts ...string) string {
	return filepath.Join(append([]string{r.root.Name()}, parts...)...)
}
@@ -1,90 +0,0 @@
package model

import (
	"io"
	"strings"
	"testing"
)

// setup is a helper function to set up the test environment.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	for m, s := range models {
		f, err := Create(m)
		if err != nil {
			t.Fatal(err)
		}

		for n, r := range s {
			w, err := f.Create(n)
			if err != nil {
				t.Fatal(err)
			}

			if _, err := io.Copy(w, r); err != nil {
				t.Fatal(err)
			}
		}

		if err := f.Close(); err != nil {
			t.Fatal(err)
		}
	}
}

func TestOpen(t *testing.T) {
	setup(t, map[Name]map[string]io.Reader{
		ParseName("namespace/model"): {
			"./.": strings.NewReader(`{"key":"value"}`),
		},
		ParseName("namespace/model:8b"): {
			"./.": strings.NewReader(`{"foo":"bar"}`),
		},
		ParseName("another/model"): {
			"./.": strings.NewReader(`{"another":"config"}`),
		},
	})

	f, err := Open(ParseName("namespace/model"))
	if err != nil {
		t.Fatal(err)
	}

	for _, name := range []string{"./."} {
		r, err := f.Open(name)
		if err != nil {
			t.Fatal(err)
		}

		if _, err := io.ReadAll(r); err != nil {
			t.Fatal(err)
		}

		if err := r.Close(); err != nil {
			t.Fatal(err)
		}
	}

	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	t.Run("does not exist", func(t *testing.T) {
		if _, err := Open(ParseName("namespace/unknown")); err == nil {
			t.Error("expected error for unknown model")
		}
	})

	t.Run("write", func(t *testing.T) {
		f, err := Open(ParseName("namespace/model"))
		if err != nil {
			t.Fatal(err)
		}
		defer f.Close()

		if _, err := f.Create("new-blob"); err == nil {
			t.Error("expected error creating blob in read-only mode")
		}
	})
}
@@ -1,33 +0,0 @@
package model

import (
	"io/fs"
	"iter"
	"path/filepath"
)

func All() (iter.Seq[Name], error) {
	r, err := root()
	if err != nil {
		return nil, err
	}

	manifests, err := r.OpenRoot("manifests")
	if err != nil {
		return nil, err
	}

	matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
	if err != nil {
		return nil, err
	}

	return func(yield func(Name) bool) {
		for _, match := range matches {
			name := ParseNameFromFilepath(filepath.ToSlash(match))
			if !yield(name) {
				return
			}
		}
	}, nil
}
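
// Iteration sketch: All is meant to be consumed with range-over-func, e.g.
//
//	names, err := All()
//	if err == nil {
//		for n := range names {
//			fmt.Println(n.String())
//		}
//	}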
|
||||
@@ -227,17 +227,6 @@ func (n Name) String() string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Set implements [flag.Value]. It parses the provided input as a name string
|
||||
// and sets the receiver to the parsed value. If the parsed name is not valid,
|
||||
// ErrUnqualifiedName is returned.
|
||||
func (n *Name) Set(s string) error {
|
||||
*n = ParseName(s)
|
||||
if !n.IsValid() {
|
||||
return ErrUnqualifiedName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisplayShortest returns a short string version of the name.
|
||||
func (n Name) DisplayShortest() string {
|
||||
var sb strings.Builder
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package manifest
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package manifest
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
@@ -14,8 +14,6 @@ import (
	"encoding/json"
	"fmt"
	"runtime"

	"github.com/ollama/ollama/x/imagegen/manifest"
)

// SupportedBackends lists the backends that support image generation.
@@ -43,8 +41,8 @@ func CheckPlatformSupport() error {
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
	modelManifest, err := manifest.LoadManifest(modelName)
	if err == nil && modelManifest.HasTensorLayers() {
	manifest, err := LoadManifest(modelName)
	if err == nil && manifest.HasTensorLayers() {
		return modelName
	}
	return ""
@@ -54,12 +52,12 @@ func ResolveModelName(modelName string) string {
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
	modelManifest, err := manifest.LoadManifest(modelName)
	manifest, err := LoadManifest(modelName)
	if err != nil {
		return ""
	}

	data, err := modelManifest.ReadConfig("model_index.json")
	data, err := manifest.ReadConfig("model_index.json")
	if err != nil {
		return ""
	}

@@ -12,7 +12,7 @@ import (
	"math"
	"time"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/qwen3"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
	m.ModelName = modelName

	// Load manifest
	manifest, err := manifest.LoadManifest(modelName)
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return fmt.Errorf("load manifest: %w", err)
	}

@@ -6,7 +6,7 @@ import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -14,19 +14,19 @@ import (

// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
	AttentionHeadDim int32 `json:"attention_head_dim"` // 128
	AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
	Eps float32 `json:"eps"` // 1e-6
	GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
	InChannels int32 `json:"in_channels"` // 128
	JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
	MLPRatio float32 `json:"mlp_ratio"` // 3.0
	NumAttentionHeads int32 `json:"num_attention_heads"` // 24
	NumLayers int32 `json:"num_layers"` // 5
	NumSingleLayers int32 `json:"num_single_layers"` // 20
	PatchSize int32 `json:"patch_size"` // 1
	RopeTheta int32 `json:"rope_theta"` // 2000
	TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
	AttentionHeadDim int32 `json:"attention_head_dim"` // 128
	AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
	Eps float32 `json:"eps"` // 1e-6
	GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
	InChannels int32 `json:"in_channels"` // 128
	JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
	MLPRatio float32 `json:"mlp_ratio"` // 3.0
	NumAttentionHeads int32 `json:"num_attention_heads"` // 24
	NumLayers int32 `json:"num_layers"` // 5
	NumSingleLayers int32 `json:"num_single_layers"` // 20
	PatchSize int32 `json:"patch_size"` // 1
	RopeTheta int32 `json:"rope_theta"` // 2000
	TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}

// Computed dimensions
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
}

// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading transformer... ")

	// Load config from blob
	var cfg TransformerConfig
	if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
	if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.TransformerConfig = &cfg
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) er
	}

	// Load weights from tensor blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

@@ -6,7 +6,7 @@ import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -15,21 +15,21 @@ import (

// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
	ActFn string `json:"act_fn"` // "silu"
	BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
	BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
	BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
	ForceUpcast bool `json:"force_upcast"` // true
	InChannels int32 `json:"in_channels"` // 3
	LatentChannels int32 `json:"latent_channels"` // 32
	LayersPerBlock int32 `json:"layers_per_block"` // 2
	ActFn string `json:"act_fn"` // "silu"
	BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
	BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
	BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
	ForceUpcast bool `json:"force_upcast"` // true
	InChannels int32 `json:"in_channels"` // 3
	LatentChannels int32 `json:"latent_channels"` // 32
	LayersPerBlock int32 `json:"layers_per_block"` // 2
	MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
	NormNumGroups int32 `json:"norm_num_groups"` // 32
	OutChannels int32 `json:"out_channels"` // 3
	PatchSize []int32 `json:"patch_size"` // [2, 2]
	SampleSize int32 `json:"sample_size"` // 1024
	UsePostQuantConv bool `json:"use_post_quant_conv"` // true
	UseQuantConv bool `json:"use_quant_conv"` // true
	NormNumGroups int32 `json:"norm_num_groups"` // 32
	OutChannels int32 `json:"out_channels"` // 3
	PatchSize []int32 `json:"patch_size"` // [2, 2]
	SampleSize int32 `json:"sample_size"` // 1024
	UsePostQuantConv bool `json:"use_post_quant_conv"` // true
	UseQuantConv bool `json:"use_quant_conv"` // true
}

// BatchNorm2D implements 2D batch normalization with running statistics
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
}

// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading VAE... ")

	// Load config from blob
	var cfg VAEConfig
	if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
	if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg

	// Load weights from tensor blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

@@ -9,8 +9,8 @@ import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,11 +38,11 @@ type Config struct {
	AttentionBias bool `json:"attention_bias"`

	// MLA (Multi-head Latent Attention) parameters
	QLoraRank int32 `json:"q_lora_rank"`
	KVLoraRank int32 `json:"kv_lora_rank"`
	QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
	QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
	VHeadDim int32 `json:"v_head_dim"`
	QLoraRank int32 `json:"q_lora_rank"`
	KVLoraRank int32 `json:"kv_lora_rank"`
	QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
	QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
	VHeadDim int32 `json:"v_head_dim"`

	// MoE parameters
	NRoutedExperts int32 `json:"n_routed_experts"`
@@ -82,7 +82,7 @@ type MLAAttention struct {
	// Absorbed MLA projections (derived from kv_b_proj)
	// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
	// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
	EmbedQ *nn.MultiLinear `weight:"-"`
	EmbedQ *nn.MultiLinear `weight:"-"`
	UnembedOut *nn.MultiLinear `weight:"-"`

	// Output projection
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {

// MoEGate implements the expert gating mechanism
type MoEGate struct {
	Gate nn.LinearLayer `weight:"mlp.gate"`
	EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
	Gate nn.LinearLayer `weight:"mlp.gate"`
	EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}

// Forward computes expert selection indices and scores
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
}

// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
	// Read config from manifest
	configData, err := modelManifest.ReadConfig("config.json")
	configData, err := manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
@@ -634,7 +634,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
	cfg.Scale = computeScale(&cfg)

	// Load weights from manifest blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}
@@ -653,7 +653,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
	}

	// Load tokenizer from manifest with config files for EOS token detection
	tokData, err := modelManifest.ReadConfig("tokenizer.json")
	tokData, err := manifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}
@@ -664,12 +664,12 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
	}

	// Try to load generation_config.json if available (preferred source for EOS)
	if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
	if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}

	// Try to load tokenizer_config.json if available
	if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
	if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}


@@ -7,7 +7,7 @@ import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -181,19 +181,19 @@ type TextEncoder struct {
}

// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
	fmt.Print(" Loading text encoder... ")

	// Load config from blob
	var cfg Config
	if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
	if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg
	m.Layers = make([]*Block, cfg.NumHiddenLayers)

	// Load weights from tensor blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

@@ -7,8 +7,8 @@ import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,7 +38,7 @@ type TransformerConfig struct {
type TimestepEmbedder struct {
	Linear1 nn.LinearLayer `weight:"mlp.0"`
	Linear2 nn.LinearLayer `weight:"mlp.2"`
	FreqEmbedSize int32 // 256 (computed)
	FreqEmbedSize int32 // 256 (computed)
}

// Forward computes timestep embeddings -> [B, 256]
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {

// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
	Norm *nn.RMSNorm `weight:"0"`
	Linear nn.LinearLayer `weight:"1"`
	PadToken *mlx.Array // loaded separately at root level
	Norm *nn.RMSNorm `weight:"0"`
	Linear nn.LinearLayer `weight:"1"`
	PadToken *mlx.Array // loaded separately at root level
}

// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
@@ -103,9 +103,10 @@ type FeedForward struct {
	W1 nn.LinearLayer `weight:"w1"` // gate projection
	W2 nn.LinearLayer `weight:"w2"` // down projection
	W3 nn.LinearLayer `weight:"w3"` // up projection
	OutDim int32 // computed from W2
	OutDim int32 // computed from W2
}


// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
@@ -131,11 +132,11 @@ type Attention struct {
	ToK nn.LinearLayer `weight:"to_k"`
	ToV nn.LinearLayer `weight:"to_v"`
	ToOut nn.LinearLayer `weight:"to_out.0"`
	NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
	NormK *mlx.Array `weight:"norm_k.weight"`
	NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
	NormK *mlx.Array `weight:"norm_k.weight"`
	// Fused QKV (computed at init time for efficiency, not loaded from weights)
	ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
	Fused bool `weight:"-"` // Whether to use fused QKV path
	Fused bool `weight:"-"` // Whether to use fused QKV path
	// Computed fields (not loaded from weights)
	NHeads int32 `weight:"-"`
	HeadDim int32 `weight:"-"`
@@ -287,13 +288,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {

// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
	Attention *Attention `weight:"attention"`
	FeedForward *FeedForward `weight:"feed_forward"`
	AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
	AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
	FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
	FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
	AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
	Attention *Attention `weight:"attention"`
	FeedForward *FeedForward `weight:"feed_forward"`
	AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
	AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
	FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
	FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
	AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
	// Computed fields
	HasModulation bool
	Dim int32
@@ -349,7 +350,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
type FinalLayer struct {
	AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
	Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
	OutDim int32 // computed from Output
	OutDim int32 // computed from Output
}

// Forward computes final output
@@ -400,12 +401,12 @@ type Transformer struct {
}

// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading transformer... ")

	// Load config from blob
	var cfg TransformerConfig
	if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
	if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	if len(cfg.AllPatchSize) > 0 {
@@ -416,7 +417,7 @@ func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
	m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.Layers = make([]*TransformerBlock, cfg.NLayers)

	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

@@ -6,7 +6,7 @@ import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/vae"
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	if ub.Upsample != nil {
		// Stage 1: Upsample2x (nearest neighbor)
		{
			prev := x
			prev := x
			x = Upsample2x(x)
			prev.Free()
			mlx.Eval(x)
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {

		// Stage 2: Upsample conv
		{
			prev := x
			prev := x
			x = ub.Upsample.Forward(x)
			prev.Free()
			mlx.Eval(x)
@@ -643,16 +643,16 @@ type VAEDecoder struct {
}

// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
	// Load config from blob
	var cfg VAEConfig
	if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
	if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg

	// Load weights from tensor blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

@@ -8,8 +8,8 @@ import (
	"fmt"
	"time"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
	"github.com/ollama/ollama/x/imagegen/vae"
@@ -18,14 +18,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
	Prompt string
	NegativePrompt string // Empty = no CFG
	CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
	Width int32 // Image width (default: 1024)
	Height int32 // Image height (default: 1024)
	Steps int // Denoising steps (default: 9 for turbo)
	Seed int64 // Random seed
	NegativePrompt string // Empty = no CFG
	CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
	Width int32 // Image width (default: 1024)
	Height int32 // Image height (default: 1024)
	Steps int // Denoising steps (default: 9 for turbo)
	Seed int64 // Random seed
	Progress func(step, totalSteps int) // Optional progress callback
	CapturePath string // GPU capture path (debug)
	CapturePath string // GPU capture path (debug)

	// TeaCache options (timestep embedding aware caching)
	TeaCache bool // TeaCache is always enabled for faster inference
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
	m.ModelName = modelName

	// Load manifest
	manifest, err := manifest.LoadManifest(modelName)
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return fmt.Errorf("load manifest: %w", err)
	}

@@ -1,203 +0,0 @@
//go:build mlx

// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// Execute is the entry point for the unified MLX runner subprocess.
func Execute(args []string) error {
	// Set up logging with appropriate level from environment
	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))

	fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
	modelName := fs.String("model", "", "path to model")
	port := fs.Int("port", 0, "port to listen on")

	if err := fs.Parse(args); err != nil {
		return err
	}

	if *modelName == "" {
		return fmt.Errorf("--model is required")
	}
	if *port == 0 {
		return fmt.Errorf("--port is required")
	}

	// Initialize MLX
	if err := mlx.InitMLX(); err != nil {
		slog.Error("unable to initialize MLX", "error", err)
		return err
	}
	slog.Info("MLX library initialized")

	// Detect model type from capabilities
	mode := detectModelMode(*modelName)
	slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)

	// Create and start server
	server, err := newServer(*modelName, *port, mode)
	if err != nil {
		return fmt.Errorf("failed to create server: %w", err)
	}

	// Set up HTTP handlers
	mux := http.NewServeMux()
	mux.HandleFunc("/health", server.healthHandler)
	mux.HandleFunc("/completion", server.completionHandler)

	// LLM-specific endpoints
	if mode == ModeLLM {
		mux.HandleFunc("/tokenize", server.tokenizeHandler)
		mux.HandleFunc("/embedding", server.embeddingHandler)
	}

	httpServer := &http.Server{
		Addr: fmt.Sprintf("127.0.0.1:%d", *port),
		Handler: mux,
	}

	// Handle shutdown
	done := make(chan struct{})
	go func() {
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
		<-sigCh
		slog.Info("shutting down mlx runner")
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		httpServer.Shutdown(ctx)
		close(done)
	}()

	slog.Info("mlx runner listening", "addr", httpServer.Addr)
	if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
		return err
	}

	<-done
	return nil
}

// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
	// Check for image generation model by looking at model_index.json
	modelType := DetectModelType(modelName)
	if modelType != "" {
		// Known image generation model types
		switch modelType {
		case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
			return ModeImageGen
		}
	}

	// Default to LLM mode for safetensors models without known image gen types
	return ModeLLM
}

// server holds the model and handles HTTP requests.
type server struct {
	mode ModelMode
	modelName string
	port int

	// Image generation model (when mode == ModeImageGen)
	imageModel ImageModel

	// LLM model (when mode == ModeLLM)
	llmModel *llmState
}

// newServer creates a new server instance and loads the appropriate model.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
	s := &server{
		mode: mode,
		modelName: modelName,
		port: port,
	}

	switch mode {
	case ModeImageGen:
		if err := s.loadImageModel(); err != nil {
			return nil, fmt.Errorf("failed to load image model: %w", err)
		}
	case ModeLLM:
		if err := s.loadLLMModel(); err != nil {
			return nil, fmt.Errorf("failed to load LLM model: %w", err)
		}
	}

	return s, nil
}

func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
	resp := HealthResponse{Status: "ok"}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}

func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req Request
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	switch s.mode {
	case ModeImageGen:
		s.handleImageCompletion(w, r, req)
	case ModeLLM:
		s.handleLLMCompletion(w, r, req)
	}
}

func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
	if s.llmModel == nil {
		http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
		return
	}

	var req struct {
		Content string `json:"content"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	tok := s.llmModel.model.Tokenizer()
	tokens := tok.Encode(req.Content, false)

	// Convert int32 to int for JSON response
	intTokens := make([]int, len(tokens))
	for i, t := range tokens {
		intTokens[i] = int(t)
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
}

func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
	http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
}
@@ -1,471 +0,0 @@
package imagegen

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/x/imagegen/manifest"
)

// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
	mu sync.Mutex
	cmd *exec.Cmd
	port int
	modelName string
	mode ModelMode
	vramSize uint64
	done chan error
	client *http.Client
	lastErr string // Last stderr line for error reporting
	lastErrLock sync.Mutex
}

// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
	// Validate platform support before attempting to start
	if err := CheckPlatformSupport(); err != nil {
		return nil, err
	}

	// Find a free port
	port := 0
	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		if l, err := net.ListenTCP("tcp", a); err == nil {
			port = l.Addr().(*net.TCPAddr).Port
			l.Close()
		}
	}
	if port == 0 {
		port = rand.Intn(65535-49152) + 49152
	}

	// Get the current executable path (we use the same binary with runner subcommand)
	exe, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
	}
	if eval, err := filepath.EvalSymlinks(exe); err == nil {
		exe = eval
	}

	// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
	cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
	cmd.Env = os.Environ()

	// On Linux, set LD_LIBRARY_PATH to include MLX library directories
	if runtime.GOOS == "linux" {
		// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
		libraryPaths := []string{ml.LibOllamaPath}
		if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
			libraryPaths = append(libraryPaths, mlxDirs...)
		}

		// Append existing LD_LIBRARY_PATH if set
		if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
			libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
		}

		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

		// Update or add LD_LIBRARY_PATH in cmd.Env
		found := false
		for i := range cmd.Env {
			if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
				cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
				found = true
				break
			}
		}
		if !found {
			cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
		}
		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
	}

	// Estimate VRAM based on tensor size from manifest
	var vramSize uint64
	if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
		vramSize = uint64(modelManifest.TotalTensorSize())
	} else {
		// Fallback: default to 8GB if manifest can't be loaded
		vramSize = 8 * 1024 * 1024 * 1024
	}

	s := &Server{
		cmd: cmd,
		port: port,
		modelName: modelName,
		mode: mode,
		vramSize: vramSize,
		done: make(chan error, 1),
		client: &http.Client{Timeout: 10 * time.Minute},
	}

	// Forward subprocess stdout/stderr to server logs
	stdout, _ := cmd.StdoutPipe()
	stderr, _ := cmd.StderrPipe()
	go func() {
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			slog.Info("mlx-runner", "msg", scanner.Text())
		}
	}()
	go func() {
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			slog.Warn("mlx-runner", "msg", line)
			s.lastErrLock.Lock()
			s.lastErr = line
			s.lastErrLock.Unlock()
		}
	}()

	slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("failed to start mlx runner: %w", err)
	}

	// Reap subprocess when it exits
	go func() {
		err := cmd.Wait()
		s.done <- err
	}()

	// Wait for subprocess to be ready
	if err := s.waitUntilRunning(); err != nil {
		s.Close()
		return nil, err
	}

	return s, nil
}

// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
	return s.modelName
}

// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
	return nil, nil
}

// Ping checks if the subprocess is healthy.
func (s *Server) Ping(ctx context.Context) error {
	url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return err
	}
	resp, err := s.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("health check failed: %d", resp.StatusCode)
	}
	return nil
}

// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
	ctx := context.Background()
	timeout := time.After(2 * time.Minute)
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case err := <-s.done:
			// Include recent stderr lines for better error context
			errMsg := s.getLastErr()
			if errMsg != "" {
				return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
			}
			return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
		case <-timeout:
			errMsg := s.getLastErr()
			if errMsg != "" {
				return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
			}
			return errors.New("timeout waiting for mlx runner to start")
		case <-ticker.C:
			if err := s.Ping(ctx); err == nil {
				slog.Info("mlx runner is ready", "port", s.port)
				return nil
			}
		}
	}
}

// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
	s.lastErrLock.Lock()
	defer s.lastErrLock.Unlock()
	return s.lastErr
}

// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
	return nil
}

// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
	seed := req.Seed
	if seed == 0 {
		seed = time.Now().UnixNano()
	}

	// Extract raw image bytes from llm.ImageData slice
	var images [][]byte
	for _, img := range req.Images {
		images = append(images, img.Data)
	}

	// Build request for subprocess
	creq := Request{
		Prompt: req.Prompt,
		Width: req.Width,
		Height: req.Height,
		Steps: int(req.Steps),
		Seed: seed,
		Images: images,
	}

	// Pass LLM options if present
	if req.Options != nil {
		creq.Options = &RequestOptions{
			NumPredict: req.Options.NumPredict,
			Temperature: float64(req.Options.Temperature),
			TopP: float64(req.Options.TopP),
			TopK: req.Options.TopK,
			Stop: req.Options.Stop,
		}
	}

	body, err := json.Marshal(creq)
	if err != nil {
		return err
	}

	url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(httpReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("%s", strings.TrimSpace(string(body)))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
	for scanner.Scan() {
		// Parse subprocess response
		var raw struct {
			Image string `json:"image,omitempty"`
			Content string `json:"content,omitempty"`
			Done bool `json:"done"`
			Step int `json:"step,omitempty"`
			Total int `json:"total,omitempty"`
			StopReason string `json:"stop_reason,omitempty"`
			PromptEvalCount int `json:"prompt_eval_count,omitempty"`
			PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
			EvalCount int `json:"eval_count,omitempty"`
			EvalDuration int `json:"eval_duration,omitempty"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
			slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
			continue
		}

		// Log stop reason when generation completes
		if raw.Done && raw.StopReason != "" {
			slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
		}

		// Convert to llm.CompletionResponse
		cresp := llm.CompletionResponse{
			Content: raw.Content,
			Done: raw.Done,
			Step: raw.Step,
			TotalSteps: raw.Total,
			Image: raw.Image,
			PromptEvalCount: raw.PromptEvalCount,
			PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
			EvalCount: raw.EvalCount,
			EvalDuration: time.Duration(raw.EvalDuration),
		}

		fn(cresp)
		if cresp.Done {
			return nil
		}
	}

	// Scanner exited without receiving Done - connection was likely closed
	scanErr := scanner.Err()
	if scanErr != nil {
		slog.Error("mlx scanner error", "error", scanErr)
	} else {
		slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
	}

	// Check if subprocess is still alive
	if s.HasExited() {
		slog.Error("mlx subprocess has exited unexpectedly")
	}

	return scanErr
}

// Close terminates the subprocess.
func (s *Server) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.cmd != nil && s.cmd.Process != nil {
		slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
		s.cmd.Process.Signal(os.Interrupt)

		// Wait briefly for graceful shutdown
		select {
		case <-s.done:
		case <-time.After(5 * time.Second):
			s.cmd.Process.Kill()
		}
		s.cmd = nil
	}
	return nil
}

// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
	return s.vramSize
}

// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
	return s.vramSize
}

// VRAMByGPU returns VRAM usage for a specific GPU.
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
	return s.vramSize
}

// ContextLength returns the context length (not applicable for image generation).
func (s *Server) ContextLength() int {
	return 0
}

// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	return nil, 0, errors.New("embeddings not supported for MLX models")
}

// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
	body, err := json.Marshal(map[string]string{"content": content})
	if err != nil {
		return nil, err
	}

	url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
	}

	var result struct {
		Tokens []int `json:"tokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}

	return result.Tokens, nil
}

// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
	return "", errors.New("detokenization not supported for MLX models")
}

// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.cmd != nil && s.cmd.Process != nil {
		return s.cmd.Process.Pid
	}
	return -1
}

// GetPort returns the port the subprocess is listening on.
func (s *Server) GetPort() int {
	return s.port
}

// GetDeviceInfos returns device information.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	return nil
}

// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
	select {
	case <-s.done:
		return true
	default:
		return false
	}
}

// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)
@@ -1,6 +1,6 @@
//go:build mlx

package manifest
package imagegen

import (
	"fmt"
@@ -15,9 +15,9 @@ import (
type ManifestWeights struct {
	manifest *ModelManifest
	component string
	tensors map[string]ManifestLayer // name -> layer
	cache map[string]*mlx.Array // name -> loaded array
	nativeCache []*mlx.SafetensorsFile // keep native handles alive
	tensors map[string]ManifestLayer // name -> layer
	cache map[string]*mlx.Array // name -> loaded array
	nativeCache []*mlx.SafetensorsFile // keep native handles alive
}

// LoadWeightsFromManifest creates a weight loader from manifest storage.
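Taken together, the hunks above fold the x/imagegen/manifest package into x/imagegen itself. A caller-side sketch of the rename, using only the identifiers visible in these hunks (illustrative, not part of the diff):

	// Before this change:
	//   import "github.com/ollama/ollama/x/imagegen/manifest"
	//   m, err := manifest.LoadManifest(modelName)       // *manifest.ModelManifest
	//   w, err := manifest.LoadWeightsFromManifest(m, "transformer")
	//
	// After this change, the same API lives at the package root:
	//   import "github.com/ollama/ollama/x/imagegen"
	//   m, err := imagegen.LoadManifest(modelName)       // *imagegen.ModelManifest
	//   w, err := imagegen.LoadWeightsFromManifest(m, "transformer")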
77
x/kvcache/cache.go
Normal file
@@ -0,0 +1,77 @@
package kvcache

import (
	"errors"

	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

var (
	ErrKvCacheFull = errors.New("could not find a kv cache slot")
	ErrNotSupported = errors.New("model does not support operation")
)

type Cache interface {
	// ** used by model implementations **

	// SetLayer sets the active layer of the cache
	SetLayer(layer int)

	// Get returns the history of key and value tensors plus a mask
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

	// Put stores a batch of key and value in the cache
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Put(ctx ml.Context, key, value ml.Tensor)

	// SetConfig controls optimizations (mostly backend-specific) that may transform
	// the output of the cache to work better with specific kernels. If not called,
	// the backend settings will be used. This works well when calling Attention.
	//
	// The config can be overridden by models, especially if they require vanilla
	// output when implementing their own version of attention. To do this, pass
	// an empty ml.CacheConfig.
	//
	// Most models will not need to use this.
	SetConfig(ml.CacheConfig)

	// ** cache management **

	// Init sets up runtime parameters.
	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
	// dtype: The data type for storing cache entries
	// maxSequences: The maximum number of sequences stored in the cache - across all batches
	// capacity: The number of cache entries to store, per sequence
	// maxBatch: The maximum number of tokens that can occur in a single batch
	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

	// Close closes the cache and frees resources associated with it
	Close()

	// StartForward is called before the start of the model's forward pass.
	// For each token in the coming batch, there must be a corresponding
	// entry in positions and seqs. reserve is to preallocate memory
	// without actually storing data in the cache.
	StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
	CopyPrefix(srcSeq, dstSeq int, len int32)

	// CanResume returns true if the cache can continue with the next token at
	// the given position and sequence. Assumes that the caller has already
	// verified the contents of the cache.
	CanResume(seq int, pos int32) bool

	// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
	// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
	//
	// If an error occurs, the entire context for the sequence should be
	// removed by calling Remove(seq, 0, math.MaxInt32)
	Remove(seq int, beginIndex, endIndex int32) error
}
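A rough sketch of how a model layer would drive this interface during attention (hypothetical helper, not part of this file; the Cache calls come from the interface above, everything else is assumed):

	// attendWithCache appends one layer's batch of keys/values and returns
	// the full cached history plus the attention mask.
	func attendWithCache(ctx ml.Context, layer int, k, v ml.Tensor, cache Cache) (ml.Tensor, ml.Tensor, ml.Tensor) {
		cache.SetLayer(layer) // select this layer's storage
		cache.Put(ctx, k, v)  // store the current batch
		return cache.Get(ctx) // history of K, V, plus a mask
	}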
144
x/kvcache/causal.go
Normal file
@@ -0,0 +1,144 @@
//go:build mlx

package kvcache

import (
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
	DType ml.DType

	// locations for data storage for this batch
	curLocPut ml.Tensor

	// locations for data storage for this batch
	curLocGet ml.Tensor

	// the active layer for Get and Put
	curLayer int

	capacity int

	offset int

	backend ml.Backend
	ctxs map[int]ml.Context
	keys, values map[int]ml.Tensor

	// TODO is this needed per layer, or will it always be consistent?
	kHeadDims, vHeadDims, numKVHeads map[int]int
}

func NewCausalCache() *Causal {
	return &Causal{
		ctxs: make(map[int]ml.Context),
		keys: make(map[int]ml.Tensor),
		values: make(map[int]ml.Tensor),
		kHeadDims: make(map[int]int),
		vHeadDims: make(map[int]int),
		numKVHeads: make(map[int]int),
	}
}

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.DType = dtype
	c.capacity = capacity
	c.backend = backend
}

func (c *Causal) SetConfig(config ml.CacheConfig) {}

func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}

func (c *Causal) Close() {
	// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	locsPut := make([]int32, len(batch.Positions))
	for i := c.offset; i < len(batch.Positions); i++ {
		locsPut[i-c.offset] = int32(i)
	}
	c.offset += len(batch.Positions)
	locsGet := make([]int32, c.offset)
	for i := range c.offset {
		locsGet[i] = int32(i)
	}
	c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
	c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
	// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)

	return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(3)
	vHeadDim := value.Dim(3)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)
	kCellSize := kHeadDim * numKVHeads
	vCellSize := vHeadDim * numKVHeads
	// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)

	if _, ok := c.ctxs[c.curLayer]; !ok {
		// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
		c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
		c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
		c.kHeadDims[c.curLayer] = kHeadDim
		c.vHeadDims[c.curLayer] = vHeadDim
		c.numKVHeads[c.curLayer] = numKVHeads
	}
	key = key.Reshape(ctx, batchSize, 1, kCellSize)

	// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
	// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
	// slog.Info("XXX Causal.Put ", "key", key)
	ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
	value = value.Reshape(ctx, batchSize, 1, vCellSize)
	ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))

}

func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := c.kHeadDims[c.curLayer]
	vHeadDim := c.vHeadDims[c.curLayer]
	numKVHeads := c.numKVHeads[c.curLayer]
	// rowSize := numKVHeads * c.curBatchSize
	// cachedSize := c.curMask.Dim(1)
	cachedSize := c.curLocGet.Dim(0)
	// kCellSize := kHeadDim * numKVHeads
	// vCellSize := vHeadDim * numKVHeads
	// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})

	key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
	value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
	return key, value, nil
}

func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	panic("not implemented")
}

func (c *Causal) CanResume(seq int, pos int32) bool {
	panic("not implemented")
}

func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	panic("not implemented")
}
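For orientation, a condensed lifecycle sketch for this cache (hypothetical driver code; backend, ctx, batch, numLayers, dtype, and the k/v tensors are assumptions, while the method calls come from the file above):

	cache := NewCausalCache()
	cache.Init(backend, dtype, 1, 4096, 512) // one sequence, 4096-entry capacity, 512-token batches
	defer cache.Close()

	if err := cache.StartForward(ctx, batch, false); err != nil {
		return err
	}
	for layer := 0; layer < numLayers; layer++ {
		cache.SetLayer(layer)
		// ... compute k, v for this layer, then:
		cache.Put(ctx, k, v)
		hk, hv, _ := cache.Get(ctx)
		// ... attention over the cached history hk/hv
	}

Note that Put scatters each batch row into its absolute slot (curLocPut) while Get gathers the whole prefix (curLocGet), so this initial implementation tracks a single monotonically growing sequence; CopyPrefix, CanResume, and Remove are still stubs.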
156
x/kvcache/encoder.go
Normal file
@@ -0,0 +1,156 @@
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Encoder cache stores K and V tensors that are position independent
|
||||
// //
|
||||
// // The tensors can be of any shape and will be returned as they were stored
|
||||
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
//	// config controls mostly backend-specific optimizations
//	config *ml.CacheConfig

//	// ** current forward pass **

//	// the active layer for Get and Put
//	curLayer int

//	// if something is stored during this pass, this
//	// will be the position (but there is no guarantee
//	// anything will be stored)
//	curPos int32

//	// curReserve indicates that this forward pass is only for
//	// memory reservation and we should not update our metadata
//	// based on it.
//	curReserve bool

//	// ** cache metadata **

//	// was something stored in the cache?
//	encoderCached bool

//	// position of the cached data
//	encoderPos int32

//	// ** cache data storage **
//	backend      ml.Backend
//	ctxs         map[int]ml.Context
//	keys, values map[int]ml.Tensor
// }

// func NewEncoderCache() *EncoderCache {
//	return &EncoderCache{
//		ctxs:   make(map[int]ml.Context),
//		keys:   make(map[int]ml.Tensor),
//		values: make(map[int]ml.Tensor),
//	}
// }

// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
//	if c.config == nil {
//		var config ml.CacheConfig
//		if cc, ok := backend.(ml.BackendCacheConfig); ok {
//			config = cc.CacheConfig()
//		}
//		c.config = &config
//	}

//	if maxSequences > 1 {
//		panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
//	}

//	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
//		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
//	}

//	c.backend = backend
// }

// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
//	if c.config != nil {
//		panic("config cannot be changed after being previously set, either by the model or backend")
//	}

//	c.config = &config
// }

// func (c *EncoderCache) Close() {
//	for _, ctx := range c.ctxs {
//		ctx.Close()
//	}
// }

// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
//	// We work with the most recent image
//	if len(batch.Multimodal) > 0 {
//		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
//	}

//	c.curReserve = reserve

//	return nil
// }

// func (c *EncoderCache) SetLayer(layer int) {
//	c.curLayer = layer
// }

// func (c *EncoderCache) EncoderCached() bool {
//	return c.encoderCached
// }

// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
//	return c.keys[c.curLayer], c.values[c.curLayer], nil
// }

// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
//	if !c.curReserve {
//		c.encoderPos = c.curPos
//		c.encoderCached = true
//	}

//	if c.config.PermutedV {
//		value = value.Transpose(ctx, 1, 2, 0, 3)
//	}

//	if _, ok := c.ctxs[c.curLayer]; !ok {
//		c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
//	}

//	if _, ok := c.keys[c.curLayer]; !ok {
//		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
//	}

//	if _, ok := c.values[c.curLayer]; !ok {
//		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
//	}

//	ctx.Forward(
//		key.Copy(ctx, c.keys[c.curLayer]),
//		value.Copy(ctx, c.values[c.curLayer]),
//	)
// }

// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
//	panic("encoder cache does not support multiple sequences")
// }

// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
//	return true
// }

// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
//	if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
//		c.encoderCached = false
//	}

//	return nil
// }
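// Example (hypothetical): how a multimodal forward pass might drive the
// commented-out EncoderCache above. The backend, ctx, batch, numLayers, and
// per-layer encKeys/encValues tensors are assumptions for illustration only;
// none of them are defined in this diff.
//
//	cache := NewEncoderCache()
//	cache.Init(backend, ml.DTypeFloat16, 1 /* maxSequences */, 0 /* capacity */, 1 /* maxBatch */)
//	defer cache.Close()
//
//	if err := cache.StartForward(ctx, batch, false /* reserve */); err != nil {
//		return err
//	}
//
//	for layer := 0; layer < numLayers; layer++ {
//		cache.SetLayer(layer)
//		if !cache.EncoderCached() {
//			cache.Put(ctx, encKeys[layer], encValues[layer]) // computed once per image
//		}
//		k, v, _ := cache.Get(ctx) // reused on later decode steps
//		_, _ = k, v
//	}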
110  x/kvcache/wrapper.go  Normal file
@@ -0,0 +1,110 @@
package kvcache

// import (
//	"math"

//	"github.com/ollama/ollama/ml"
//	"github.com/ollama/ollama/model/input"
// )

// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
//	// caches we are wrapping
//	caches []Cache

//	// cache to be used for this layer
//	curType int
// }

// func NewWrapperCache(caches ...Cache) *WrapperCache {
//	return &WrapperCache{
//		caches: caches,
//	}
// }

// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
//	for _, cache := range c.caches {
//		cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
//	}
// }

// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
//	for _, cache := range c.caches {
//		cache.SetConfig(config)
//	}
// }

// func (c *WrapperCache) Close() {
//	for _, cache := range c.caches {
//		cache.Close()
//	}
// }

// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
//	for i, cache := range c.caches {
//		err := cache.StartForward(ctx, batch, reserve)
//		if err != nil {
//			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
//			for j := i - 1; j >= 0; j-- {
//				for k := range batch.Positions {
//					_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
//				}
//			}
//			return err
//		}
//	}

//	c.curType = 0
//	return nil
// }

// func (c *WrapperCache) SetLayer(layer int) {
//	for _, cache := range c.caches {
//		cache.SetLayer(layer)
//	}
// }

// func (c *WrapperCache) SetLayerType(layerType int) {
//	c.curType = layerType
// }

// func (c *WrapperCache) UnderlyingCache() Cache {
//	return c.caches[c.curType]
// }

// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
//	return c.caches[c.curType].Get(ctx)
// }

// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
//	c.caches[c.curType].Put(ctx, key, value)
// }

// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
//	for _, cache := range c.caches {
//		cache.CopyPrefix(srcSeq, dstSeq, len)
//	}
// }

// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
//	for _, cache := range c.caches {
//		if !cache.CanResume(seq, pos) {
//			return false
//		}
//	}

//	return true
// }

// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
//	// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
//	for _, cache := range c.caches {
//		err := cache.Remove(seq, beginIndex, endIndex)
//		if err != nil {
//			return err
//		}
//	}

//	return nil
// }
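// Example (hypothetical): composing a WrapperCache for an encoder-decoder
// model. NewCausalCache is assumed to exist elsewhere in this package; only
// EncoderCache and WrapperCache appear in this diff.
//
//	const (
//		cacheTypeDecoder = iota // self-attention layers
//		cacheTypeEncoder        // cross-attention layers reading encoder state
//	)
//
//	cache := NewWrapperCache(NewCausalCache(nil) /* hypothetical */, NewEncoderCache())
//
//	// Per layer, the model selects which wrapped cache receives Get/Put:
//	cache.SetLayerType(cacheTypeEncoder)
//	k, v, mask := cache.Get(ctx) // delegates to the encoder cache
//	_, _, _ = k, v, mask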
433  x/ml/backend.go  Normal file
@@ -0,0 +1,433 @@
package ml

import (
	"fmt"
	"log/slog"
	"os"

	"github.com/ollama/ollama/fs"
)

type Backend interface {
	// Close frees all memory associated with this backend
	// Close()

	// Load(ctx context.Context, progress func(float32)) error

	// BackendMemory returns the memory allocations that were made for this model
	// BackendMemory() BackendMemory

	Config() fs.Config
	Get(name string) Tensor
	NewContext() Context
	// NewContextSize(size int) Context

	// Enumerate the devices available for inference via this backend
	// BackendDevices() []DeviceInfo
}

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeFloat32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
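// Example (hypothetical): a backend advertising its cache requirements. The
// mlxBackend type is an assumption for illustration; only the interfaces above
// are part of this diff.
//
//	type mlxBackend struct{ /* ... */ }
//
//	func (b *mlxBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding: 1,            // no padding requirement
//			PermutedV:    true,         // fused attention kernel wants permuted v
//			MaskDType:    DTypeFloat16, // generate masks in half precision
//		}
//	}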
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// AllocMemory causes the backend to allocate memory for the model. If
	// false, this is only being used for discovering the required amount of
	// memory and cannot load the model for running.
	AllocMemory bool

	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// GPULayers is the set of layers to offload to GPUs
	GPULayers GPULayersList

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}

var backends = make(map[string]func(string, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}

func NewBackend(modelPath string, params BackendParams) (Backend, error) {
	be := os.Getenv("OLLAMA_BACKEND")
	if be == "" {
		be = "mlx"
		slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
	}
	slog.Info("Loading new engine", "backend", be)
	if backend, ok := backends[be]; ok {
		return backend(modelPath, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
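// Example (hypothetical): how a backend package might register itself and how
// a caller selects it. newMLXBackend is an assumption for illustration; the
// registry functions above are the only part taken from this diff.
//
//	func init() {
//		RegisterBackend("mlx", func(modelPath string, params BackendParams) (Backend, error) {
//			return newMLXBackend(modelPath, params)
//		})
//	}
//
//	// With OLLAMA_BACKEND unset, NewBackend falls through to "mlx":
//	b, err := NewBackend("model.safetensors", BackendParams{AllocMemory: true, NumThreads: 8})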
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	// FromBytes(dtype DType, s []byte, shape ...int) Tensor
	FromFloats(s []float32, shape ...int) Tensor
	FromInts(s []int32, shape ...int) Tensor
	RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor

	// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
	Arange(start, stop, step float32, dtype DType) Tensor

	Forward(...Tensor) Context

	// SetBatchSize provides a hint on the batch size to optimize processing.
	// Uses heuristics if not set.
	// SetBatchSize(int)

	Compute(...Tensor)
	// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

	// Reserve is analogous to Compute but rather than executing a
	// graph, simply preallocates memory. Typically called with a
	// worst case graph to ensure all resources are available for
	// future inference.
	// Reserve()

	// MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating tensors that are
	// inputs to the model (which includes things like output locations)
	Input() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context

	// CompareWith loads tensors from the named safetensors file and compares them
	// with the input tensors. It returns an error if a shape is inconsistent or
	// similarity measures are below 99%.
	CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}

type RoPEOptions struct {
	Base  *float32
	Freqs Tensor
}

func WithRoPEBase(base float32) func(*RoPEOptions) {
	return func(opts *RoPEOptions) {
		opts.Base = &base
	}
}

func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
	return func(opts *RoPEOptions) {
		opts.Freqs = freqs
	}
}
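// Example (hypothetical): applying rotary embeddings with a non-default base
// via the options above. q is an assumed query tensor; dims, offset, and the
// scale are illustrative values.
//
//	q = q.RoPE(ctx, 128 /* dims */, false /* traditional */, 1.0 /* scale */, offset,
//		WithRoPEBase(1e6))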
type Tensor interface {
	ToString() string
	RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
	ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
	TakeAxes(ctx Context, indices Tensor, axes int) Tensor
	// TakeAxes(ctx Context, axes int, indices ...int) Tensor

	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType
	// Cast(ctx Context, dtype DType) Tensor

	// Bytes() []byte
	Floats() []float32
	Ints() []int32

	// FromBytes([]byte)
	// FromFloats([]float32)
	// FromInts([]int32)

	Add(ctx Context, t2 Tensor) Tensor
	Sub(ctx Context, t2 Tensor) Tensor
	// Mul(ctx Context, t2 Tensor) Tensor
	// Div(ctx Context, t2 Tensor) Tensor

	Max(ctx Context, axes []int, keepDims bool) Tensor
	Min(ctx Context, axes []int, keepDims bool) Tensor

	Matmul(ctx Context, a2 Tensor) Tensor
	// Mulmat(ctx Context, t2 Tensor) Tensor
	// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
	// MulmatID(ctx Context, t2, ids Tensor) Tensor
	// AddID(ctx Context, t2, ids Tensor) Tensor

	Softmax(ctx Context) Tensor
	L2Norm(ctx Context, eps float32) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor
	// SumRows(ctx Context) Tensor

	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
	Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor

	// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	// Sin(ctx Context) Tensor
	// Cos(ctx Context) Tensor
	// Tanh(ctx Context) Tensor
	GELU(ctx Context, up ...Tensor) Tensor
	// QuickGELU(ctx Context, up ...Tensor) Tensor
	// SILU(ctx Context, up ...Tensor) Tensor
	// RELU(ctx Context, up ...Tensor) Tensor
	// Sigmoid(ctx Context) Tensor

	// SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
	// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	AsStrided(ctx Context, shape, strides []int, offset int) Tensor
	Transpose(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context, allowColMajor bool) Tensor

	// Pad(ctx Context, shape ...int) Tensor

	// Stack(ctx Context, dim int, s ...Tensor) Tensor

	// Repeat repeats the tensor n times along dimension dim
	// Repeat(ctx Context, dim, n int) Tensor
	// Concat(ctx Context, t2 Tensor, dim int) Tensor
	// Rows(ctx Context, t2 Tensor) Tensor

	// TODO: these probably aren't actually needed - false starts on trying to wire up the cache
	// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
	// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
	// PutAlongAxis(ctx Context, indices, values Tensor, axis int) Tensor

	Scatter(ctx Context, indices []Tensor, updates Tensor, axes []int) Tensor

	Copy(ctx Context, t2 Tensor) Tensor
	// Duplicate(ctx Context) Tensor

	// Slice(ctx Context, dim, low, high, step int) Tensor
	// Chunk(ctx Context, dim int, size int) []Tensor
	// ChunkSections(ctx Context, dim int, sections ...int) []Tensor

	// TopK(ctx Context, k int) Tensor
	// Argsort(ctx Context) Tensor
	// Mean(ctx Context) Tensor
	// Variance(ctx Context) Tensor
	// Stddev(ctx Context) Tensor
	// Sqr(ctx Context) Tensor
	// Sqrt(ctx Context) Tensor

	// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
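// Example (hypothetical): building and computing a small graph with the
// Context and Tensor interfaces above. b is assumed to be a Backend obtained
// from NewBackend; shapes and values are illustrative only.
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	a := ctx.FromFloats([]float32{1, 2, 3, 4}, 2, 2)
//	w := ctx.FromFloats([]float32{5, 6, 7, 8}, 2, 2)
//	out := a.Matmul(ctx, w).Softmax(ctx)
//
//	ctx.Forward(out)
//	ctx.Compute(out)
//	fmt.Println(out.Floats())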
// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
//	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }

// type number interface {
//	~int | ~int8 | ~int16 | ~int32 | ~int64 |
//		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
//		~float32 | ~float64 |
//		~complex64 | ~complex128
// }

// func mul[T number](s ...T) T {
//	p := T(1)
//	for _, v := range s {
//		p *= v
//	}

//	return p
// }
// type DumpOptions func(*dumpOptions)

// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
//	return func(opts *dumpOptions) {
//		opts.Precision = n
//	}
// }

// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
//	return func(opts *dumpOptions) {
//		opts.Threshold = n
//	}
// }

// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
//	return func(opts *dumpOptions) {
//		opts.EdgeItems = n
//	}
// }

// type dumpOptions struct {
//	Precision, Threshold, EdgeItems int
// }

// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
//	opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
//	for _, optsFunc := range optsFuncs {
//		optsFunc(&opts)
//	}

//	if mul(t.Shape()...) <= opts.Threshold {
//		opts.EdgeItems = math.MaxInt
//	}

//	switch t.DType() {
//	case DTypeFloat32:
//		return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
//			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//		})
//	case DTypeFloat16: // TODO other types...
//		f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
//		f32 = t.Copy(ctx, f32)
//		return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
//			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//		})
//	case DTypeInt32:
//		return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
//			return strconv.FormatInt(int64(i), 10)
//		})
//	default:
//		return "<unsupported>"
//	}
// }

// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
//	if t.Bytes() == nil {
//		ctx.Compute(t)
//	}

//	s := make(S, mul(t.Shape()...))
//	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
//		panic(err)
//	}

//	shape := t.Shape()
//	slices.Reverse(shape)

//	var sb strings.Builder
//	var f func([]int, int)
//	f = func(dims []int, stride int) {
//		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
//		sb.WriteString("[")
//		defer func() { sb.WriteString("]") }()
//		for i := 0; i < dims[0]; i++ {
//			if i >= items && i < dims[0]-items {
//				sb.WriteString("..., ")
//				// skip to next printable element
//				skip := dims[0] - 2*items
//				if len(dims) > 1 {
//					stride += mul(append(dims[1:], skip)...)
//					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
//				}
//				i += skip - 1
//			} else if len(dims) > 1 {
//				f(dims[1:], stride)
//				stride += mul(dims[1:]...)
//				if i < dims[0]-1 {
//					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
//				}
//			} else {
//				text := fn(s[stride+i])
//				if len(text) > 0 && text[0] != '-' {
//					sb.WriteString(" ")
//				}

//				sb.WriteString(text)
//				if i < dims[0]-1 {
//					sb.WriteString(", ")
//				}
//			}
//		}
//	}
//	f(shape, 0)

//	return sb.String()
// }
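// Example (hypothetical): if the commented-out Dump helpers above are enabled,
// inspecting a tensor might look like the following. The options are the ones
// defined above; the tensor t is an assumption.
//
//	fmt.Println(Dump(ctx, t, DumpWithPrecision(2), DumpWithEdgeItems(2)))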
type DType int

const (
	DTypeBool DType = iota
	DTypeUint8
	DTypeUint16
	DTypeUint32
	DTypeUint64
	DTypeInt8
	DTypeInt16
	DTypeInt32
	DTypeInt64
	DTypeFloat16
	DTypeFloat32
	DTypeFloat64
	DTypeBfloat16
	DTypeComplex64
)

type SamplingMode int

const (
	SamplingModeNearest SamplingMode = iota
	SamplingModeBilinear
)
3  x/ml/backend/backend.go  Normal file
@@ -0,0 +1,3 @@
package backend

// _ "github.com/ollama/ollama/x/ml/backend/mlx"
61  x/ml/backend/mlx/CMakeLists.txt  Normal file
@@ -0,0 +1,61 @@
include(FetchContent)

# Read MLX version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)

set(MLX_C_BUILD_EXAMPLES OFF)

set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)

function(set_target_output_directory _target)
  if(TARGET ${_target})
    set_target_properties(${_target} PROPERTIES
      RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
      LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
      ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
    )
  endif()
endfunction()

# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

  if(NOT MLX_METAL_VERSION)
    message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
    set(MLX_BUILD_METAL OFF)
  endif()
else()
  # On non-macOS platforms, disable the Metal backend
  message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
  set(MLX_BUILD_METAL OFF)
endif()

# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
  set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
  message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()

# Enable CUDA backend if CUDA architectures are specified and a CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
  set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
  message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
  message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG ${MLX_C_GIT_TAG})
FetchContent_MakeAvailable(mlx-c)

set_target_output_directory(mlx)
set_target_output_directory(mlxc)
1278  x/ml/backend/mlx/mlx.go  Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff