mirror of
https://github.com/ollama/ollama.git
synced 2026-02-05 21:23:43 -05:00
Compare commits
1 Commits
main
...
drifkin/de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e9d9acf18 |
22
.github/workflows/test-install.yaml
vendored
22
.github/workflows/test-install.yaml
vendored
@@ -1,22 +0,0 @@
|
||||
name: test-install
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'scripts/install.sh'
|
||||
- '.github/workflows/test-install.yaml'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run install script
|
||||
run: sh ./scripts/install.sh
|
||||
env:
|
||||
OLLAMA_NO_START: 1 # do not start app
|
||||
- name: Verify ollama is available
|
||||
run: ollama --version
|
||||
@@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -785,123 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4).
|
||||
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
|
||||
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -321,6 +321,8 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -603,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -676,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -729,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -776,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -840,6 +842,10 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -893,9 +899,11 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -929,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -961,105 +969,3 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
|
||||
if tokens < 1 {
|
||||
t.Errorf("expected at least 1 token, got %d", tokens)
|
||||
}
|
||||
// Sanity check: shouldn't be wildly off
|
||||
if tokens > 10 {
|
||||
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// System prompt adds to count
|
||||
if tokens < 5 {
|
||||
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather for a location",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Tools add significant content
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Thinking content should be counted
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
if tokens != 0 {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,25 +466,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
||||
type AliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
@@ -1763,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("ollama server not responding - %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
// Claude implements Runner for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
@@ -65,104 +60,3 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := selectPrompt("Select model:", items)
|
||||
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,8 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
@@ -134,16 +133,8 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
@@ -163,29 +154,6 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,677 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
|
||||
if serverErr == nil {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error
|
||||
if serverErr != nil {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -46,53 +46,6 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
|
||||
@@ -39,15 +39,6 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
|
||||
// Integrations like Claude and Codex use this to route model requests to local models.
|
||||
type AliasConfigurer interface {
|
||||
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
|
||||
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
|
||||
// SetAliases syncs the configured aliases to the server
|
||||
SetAliases(ctx context.Context, aliases map[string]string) error
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
@@ -138,11 +129,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompt := fmt.Sprintf("Select model for %s:", r)
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := selectPrompt(prompt, items)
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -170,123 +157,73 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{
|
||||
Name: m.Name,
|
||||
Remote: m.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
@@ -294,33 +231,10 @@ func runIntegration(name, modelName string, args []string) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
|
||||
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existing {
|
||||
aliases[k] = v
|
||||
}
|
||||
aliases["primary"] = model
|
||||
|
||||
if isCloudModel(ctx, client, model) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = model
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if err := ac.SetAliases(ctx, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
return saveAliases(name, aliases)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
var modelFlag string
|
||||
@@ -388,87 +302,9 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// Handle AliasConfigurer integrations (claude, codex)
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate --model flag if provided
|
||||
if modelFlag != "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
|
||||
return fmt.Errorf("model %q not found", modelFlag)
|
||||
}
|
||||
}
|
||||
|
||||
var model string
|
||||
var existingAliases map[string]string
|
||||
|
||||
// Load saved config
|
||||
if cfg, err := loadIntegration(name); err == nil {
|
||||
existingAliases = cfg.Aliases
|
||||
if len(cfg.Models) > 0 {
|
||||
model = cfg.Models[0]
|
||||
// AliasConfigurer integrations use single model; sanitize if multiple
|
||||
if len(cfg.Models) > 1 {
|
||||
_ = saveIntegration(name, []string{model})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --model flag overrides saved model
|
||||
if modelFlag != "" {
|
||||
model = modelFlag
|
||||
}
|
||||
|
||||
// Validate saved model still exists
|
||||
if model != "" && modelFlag == "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
model = ""
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
}
|
||||
|
||||
// Sync aliases and save
|
||||
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
if err := saveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
// Launch (unless --config without confirmation)
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
|
||||
// Validate --model flag for non-AliasConfigurer integrations
|
||||
if modelFlag != "" {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
|
||||
return fmt.Errorf("model %q not found", modelFlag)
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0], passArgs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -482,8 +318,6 @@ Examples:
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
return runIntegration(name, saved.Models[0], passArgs)
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
@@ -546,9 +380,8 @@ Examples:
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
Name string
|
||||
Remote bool
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
@@ -585,7 +418,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if strings.HasSuffix(rec.Name, ":cloud") {
|
||||
if isCloudModel(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
@@ -645,16 +478,8 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
// isCloudModel checks if a model is a cloud model using the Show API.
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -298,15 +297,24 @@ func TestParseArgs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIsCloudModel(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, so nil client always returns false
|
||||
t.Run("nil client returns false", func(t *testing.T) {
|
||||
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
|
||||
for _, model := range models {
|
||||
if isCloudModel(context.Background(), nil, model) {
|
||||
t.Errorf("isCloudModel(%q) with nil client should return false", model)
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"glm-4.7:cloud", true},
|
||||
{"kimi-k2.5:cloud", true},
|
||||
{"glm-4.7-flash", false},
|
||||
{"glm-4.7-flash:latest", false},
|
||||
{"cloud-model", false},
|
||||
{"model:cloudish", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isCloudModel(tt.name); got != tt.want {
|
||||
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func names(items []selectItem) []string {
|
||||
@@ -501,41 +509,3 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save a config for opencode so it looks like a previous launch
|
||||
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify loadIntegration returns the saved models
|
||||
saved, err := loadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(saved.Models) == 0 {
|
||||
t.Fatal("expected saved models")
|
||||
}
|
||||
if saved.Models[0] != "llama3.2" {
|
||||
t.Errorf("expected llama3.2, got %s", saved.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
@@ -11,52 +10,12 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
@@ -154,8 +113,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -165,29 +122,12 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
entry := map[string]any{
|
||||
models[model] = map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
|
||||
@@ -2,7 +2,6 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -496,165 +495,6 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry, ok := models[model].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("model %s not found in config", model)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
// Set up a model with a user-configured limit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"llama3.2": {
|
||||
"name": "llama3.2",
|
||||
"_launch": true,
|
||||
"limit": {"context": 8192, "output": 4096}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
// Re-edit should preserve the user's limit (not delete it)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("user-configured limit was removed")
|
||||
}
|
||||
if limit["context"] != float64(8192) {
|
||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(4096) {
|
||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||
// the structure matches what opencode expects and re-edit preserves them.
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud",
|
||||
"_launch": true,
|
||||
"limit": {"context": %d, "output": %d}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, expected.Context, expected.Output)), 0o644)
|
||||
|
||||
// Re-edit should preserve the cloud model limit
|
||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was removed on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -17,7 +17,6 @@ const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
|
||||
@@ -96,14 +96,6 @@ func TestSelectState(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
@@ -582,19 +574,8 @@ func TestRenderSelect(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "no matches") {
|
||||
t.Errorf("expected 'no matches' message, got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message for empty list with no filter")
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -10,21 +10,19 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return errNotRunning
|
||||
return err
|
||||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return errNotRunning
|
||||
return err
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errNotRunning
|
||||
return errors.New("could not find ollama app")
|
||||
}
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
|
||||
@@ -188,6 +188,8 @@ func LogLevel() slog.Level {
|
||||
var (
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
@@ -273,26 +275,27 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -34,7 +34,6 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -117,7 +116,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -143,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var tok tokenizer.Tokenizer
|
||||
var textProcessor model.TextProcessor
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
tok, err = model.NewTextProcessor(modelPath)
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -156,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if tok == nil {
|
||||
if textProcessor == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -212,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if tok == nil {
|
||||
if textProcessor == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -262,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
tok != nil,
|
||||
textProcessor != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -311,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if tok != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1775,7 +1774,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1810,7 +1809,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -131,15 +131,12 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
|
||||
272
model/bytepairencoding.go
Normal file
272
model/bytepairencoding.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
|
||||
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
|
||||
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -134,7 +133,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -151,7 +150,7 @@ func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -10,12 +10,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -130,7 +129,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
vocab := &model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -154,17 +153,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var t tokenizer.Tokenizer
|
||||
var processor model.TextProcessor
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -223,7 +222,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -278,8 +277,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,12 +10,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -135,8 +134,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -28,7 +27,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.SentencePiece
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,8 +43,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -32,8 +31,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,12 +12,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -55,7 +54,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -71,19 +70,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var t tokenizer.Tokenizer
|
||||
var processor model.TextProcessor
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: t,
|
||||
TextProcessor: processor,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,12 +6,11 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -24,8 +23,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -199,7 +198,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -237,8 +236,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,12 +11,11 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -38,8 +37,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,12 +12,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -197,8 +196,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -60,7 +59,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -79,7 +78,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -105,8 +104,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -26,7 +25,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -42,8 +41,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -81,16 +80,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,12 +11,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -34,8 +33,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,12 +11,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -29,12 +28,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,12 +11,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -33,8 +32,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,12 +11,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -179,6 +178,29 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
@@ -197,29 +219,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
Tokenizer: tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
Layers: layers,
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,7 +33,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -45,24 +44,28 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -93,7 +92,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -140,8 +139,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,12 +10,11 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,8 +27,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -35,8 +34,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -160,7 +159,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -219,8 +218,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
@@ -208,7 +207,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -354,8 +353,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,12 +10,11 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
model.TextProcessor
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -173,8 +172,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
@@ -17,7 +17,7 @@ type SentencePiece struct {
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ Tokenizer = (*SentencePiece)(nil)
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizer that use byte tokens like "<0xEA>"
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type Tokenizer interface {
|
||||
type TextProcessor interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements Tokenizer.
|
||||
// Decode implements TextProcessor.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements Tokenizer.
|
||||
// Encode implements TextProcessor.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements Tokenizer.
|
||||
// Is implements TextProcessor.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements Tokenizer.
|
||||
// Vocabulary implements TextProcessor.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ Tokenizer = (*WordPiece)(nil)
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -37,7 +37,6 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
@@ -211,9 +210,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
|
||||
decoder := func(tokenID int) string {
|
||||
text, _ := tok.Decode([]int32{int32(tokenID)})
|
||||
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
|
||||
return text
|
||||
}
|
||||
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
|
||||
@@ -243,7 +242,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -765,7 +764,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
@@ -774,14 +773,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
|
||||
if seq.logprobs {
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
|
||||
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
|
||||
}
|
||||
|
||||
@@ -879,7 +878,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -168,15 +168,15 @@ type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
|
||||
pieces := make([]string, len(tok.Vocabulary().Values))
|
||||
for i := range tok.Vocabulary().Values {
|
||||
pieces[i], _ = tok.Decode([]int32{int32(i)})
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
return model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
Merges: merges,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/bin/sh
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# This script installs Ollama on Linux.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
set -eu
|
||||
@@ -27,7 +27,8 @@ require() {
|
||||
echo $MISSING
|
||||
}
|
||||
|
||||
OS="$(uname -s)"
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
@@ -35,65 +36,6 @@ case "$ARCH" in
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
###########################################
|
||||
# macOS
|
||||
###########################################
|
||||
|
||||
if [ "$OS" = "Darwin" ]; then
|
||||
NEEDS=$(require curl unzip)
|
||||
if [ -n "$NEEDS" ]; then
|
||||
status "ERROR: The following tools are required but missing:"
|
||||
for NEED in $NEEDS; do
|
||||
echo " - $NEED"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -n "${OLLAMA_VERSION:-}" ]; then
|
||||
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
|
||||
else
|
||||
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
|
||||
fi
|
||||
|
||||
if pgrep -x Ollama >/dev/null 2>&1; then
|
||||
status "Stopping running Ollama instance..."
|
||||
pkill -x Ollama 2>/dev/null || true
|
||||
sleep 2
|
||||
fi
|
||||
|
||||
if [ -d "/Applications/Ollama.app" ]; then
|
||||
status "Removing existing Ollama installation..."
|
||||
rm -rf "/Applications/Ollama.app"
|
||||
fi
|
||||
|
||||
status "Downloading Ollama for macOS..."
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
|
||||
|
||||
status "Installing Ollama to /Applications..."
|
||||
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
|
||||
mv "$TEMP_DIR/Ollama.app" "/Applications/"
|
||||
|
||||
status "Adding 'ollama' command to PATH (may require password)..."
|
||||
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
|
||||
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
|
||||
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
|
||||
|
||||
if [ -z "${OLLAMA_NO_START:-}" ]; then
|
||||
status "Starting Ollama..."
|
||||
open -a Ollama --args hidden
|
||||
fi
|
||||
|
||||
status "Install complete. You can now run 'ollama'."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
###########################################
|
||||
# Linux
|
||||
###########################################
|
||||
|
||||
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
|
||||
|
||||
IS_WSL2=false
|
||||
|
||||
KERN=$(uname -r)
|
||||
|
||||
@@ -1,422 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const (
|
||||
serverConfigFilename = "server.json"
|
||||
serverConfigVersion = 1
|
||||
)
|
||||
|
||||
var errAliasCycle = errors.New("alias cycle detected")
|
||||
|
||||
type aliasEntry struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
type serverConfig struct {
|
||||
Version int `json:"version"`
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type store struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
|
||||
prefixEntries []aliasEntry // prefix matches, sorted longest-first
|
||||
}
|
||||
|
||||
func createStore(path string) (*store, error) {
|
||||
store := &store{
|
||||
path: path,
|
||||
entries: make(map[string]aliasEntry),
|
||||
}
|
||||
if err := store.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *store) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg serverConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
|
||||
return fmt.Errorf("unsupported router config version %d", cfg.Version)
|
||||
}
|
||||
|
||||
for _, entry := range cfg.Aliases {
|
||||
targetName := model.ParseName(entry.Target)
|
||||
if !targetName.IsValid() {
|
||||
slog.Warn("invalid alias target in router config", "target", entry.Target)
|
||||
continue
|
||||
}
|
||||
canonicalTarget := displayAliasName(targetName)
|
||||
|
||||
if entry.PrefixMatching {
|
||||
// Prefix aliases don't need to be valid model names
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
slog.Warn("empty prefix alias in router config")
|
||||
continue
|
||||
}
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: alias,
|
||||
Target: canonicalTarget,
|
||||
PrefixMatching: true,
|
||||
})
|
||||
} else {
|
||||
aliasName := model.ParseName(entry.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
|
||||
continue
|
||||
}
|
||||
canonicalAlias := displayAliasName(aliasName)
|
||||
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
|
||||
Alias: canonicalAlias,
|
||||
Target: canonicalTarget,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort prefix entries by alias length descending (longest prefix wins)
|
||||
s.sortPrefixEntriesLocked()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *store) saveLocked() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine exact and prefix entries
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
|
||||
cfg := serverConfig{
|
||||
Version: serverConfigVersion,
|
||||
Aliases: entries,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(dir, "router-*.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Chmod(f.Name(), 0o644); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Rename(f.Name(), s.path)
|
||||
}
|
||||
|
||||
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
|
||||
// If a local model exists, do not allow alias shadowing (highest priority).
|
||||
exists, err := localModelExists(name)
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
if exists {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
key := normalizeAliasKey(name)
|
||||
|
||||
s.mu.RLock()
|
||||
entry, exactMatch := s.entries[key]
|
||||
var prefixMatch *aliasEntry
|
||||
if !exactMatch {
|
||||
// Try prefix matching - prefixEntries is sorted longest-first
|
||||
nameStr := strings.ToLower(displayAliasName(name))
|
||||
for i := range s.prefixEntries {
|
||||
prefix := strings.ToLower(s.prefixEntries[i].Alias)
|
||||
if strings.HasPrefix(nameStr, prefix) {
|
||||
prefixMatch = &s.prefixEntries[i]
|
||||
break // First match is longest due to sorting
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exactMatch && prefixMatch == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
var current string
|
||||
var visited map[string]struct{}
|
||||
|
||||
if exactMatch {
|
||||
visited = map[string]struct{}{key: {}}
|
||||
current = entry.Target
|
||||
} else {
|
||||
// For prefix match, use the target as-is
|
||||
visited = map[string]struct{}{}
|
||||
current = prefixMatch.Target
|
||||
}
|
||||
|
||||
targetKey := normalizeAliasKeyString(current)
|
||||
|
||||
for {
|
||||
targetName := model.ParseName(current)
|
||||
if !targetName.IsValid() {
|
||||
return name, false, fmt.Errorf("alias target %q is invalid", current)
|
||||
}
|
||||
|
||||
if _, seen := visited[targetKey]; seen {
|
||||
return name, false, errAliasCycle
|
||||
}
|
||||
visited[targetKey] = struct{}{}
|
||||
|
||||
s.mu.RLock()
|
||||
next, ok := s.entries[targetKey]
|
||||
s.mu.RUnlock()
|
||||
if !ok {
|
||||
return targetName, true, nil
|
||||
}
|
||||
|
||||
current = next.Target
|
||||
targetKey = normalizeAliasKeyString(current)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
|
||||
targetKey := normalizeAliasKey(target)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if prefixMatching {
|
||||
// For prefix aliases, we skip cycle detection since prefix matching
|
||||
// works differently and the target is a specific model
|
||||
aliasStr := displayAliasName(alias)
|
||||
|
||||
// Remove any existing prefix entry with the same alias
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: aliasStr,
|
||||
Target: displayAliasName(target),
|
||||
PrefixMatching: true,
|
||||
})
|
||||
s.sortPrefixEntriesLocked()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
if aliasKey == targetKey {
|
||||
return fmt.Errorf("alias cannot point to itself")
|
||||
}
|
||||
|
||||
visited := map[string]struct{}{aliasKey: {}}
|
||||
currentKey := targetKey
|
||||
for {
|
||||
if _, seen := visited[currentKey]; seen {
|
||||
return errAliasCycle
|
||||
}
|
||||
visited[currentKey] = struct{}{}
|
||||
|
||||
next, ok := s.entries[currentKey]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
currentKey = normalizeAliasKeyString(next.Target)
|
||||
}
|
||||
|
||||
s.entries[aliasKey] = aliasEntry{
|
||||
Alias: displayAliasName(alias),
|
||||
Target: displayAliasName(target),
|
||||
}
|
||||
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *store) Delete(alias model.Name) (bool, error) {
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try exact match first
|
||||
if _, ok := s.entries[aliasKey]; ok {
|
||||
delete(s.entries, aliasKey)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
// Try prefix entries
|
||||
aliasStr := displayAliasName(alias)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// DeleteByString deletes an alias by its raw string value, useful for prefix
|
||||
// aliases that may not be valid model names.
|
||||
func (s *store) DeleteByString(alias string) (bool, error) {
|
||||
alias = strings.TrimSpace(alias)
|
||||
aliasLower := strings.ToLower(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try prefix entries first (since this is mainly for prefix aliases)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, alias) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// Also check exact entries by normalized key
|
||||
if _, ok := s.entries[aliasLower]; ok {
|
||||
delete(s.entries, aliasLower)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *store) List() []aliasEntry {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func normalizeAliasKey(name model.Name) string {
|
||||
return strings.ToLower(displayAliasName(name))
|
||||
}
|
||||
|
||||
func (s *store) sortPrefixEntriesLocked() {
|
||||
sort.Slice(s.prefixEntries, func(i, j int) bool {
|
||||
// Sort by length descending (longest prefix first)
|
||||
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeAliasKeyString(value string) string {
|
||||
n := model.ParseName(value)
|
||||
if !n.IsValid() {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
return normalizeAliasKey(n)
|
||||
}
|
||||
|
||||
func displayAliasName(n model.Name) string {
|
||||
display := n.DisplayShortest()
|
||||
if strings.EqualFold(n.Tag, "latest") {
|
||||
if idx := strings.LastIndex(display, ":"); idx != -1 {
|
||||
return display[:idx]
|
||||
}
|
||||
}
|
||||
return display
|
||||
}
|
||||
|
||||
func localModelExists(name model.Name) (bool, error) {
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
needle := name.String()
|
||||
for existing := range manifests {
|
||||
if strings.EqualFold(existing.String(), needle) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func serverConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(".ollama", serverConfigFilename)
|
||||
}
|
||||
return filepath.Join(home, ".ollama", serverConfigFilename)
|
||||
}
|
||||
|
||||
func (s *Server) aliasStore() (*store, error) {
|
||||
s.aliasesOnce.Do(func() {
|
||||
s.aliases, s.aliasesErr = createStore(serverConfigPath())
|
||||
})
|
||||
|
||||
return s.aliases, s.aliasesErr
|
||||
}
|
||||
|
||||
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
|
||||
if store == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
return store.ResolveName(name)
|
||||
}
|
||||
144
server/inference_request_log.go
Normal file
144
server/inference_request_log.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type inferenceRequestLogger struct {
|
||||
dir string
|
||||
counter uint64
|
||||
}
|
||||
|
||||
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
|
||||
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &inferenceRequestLogger{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *Server) initRequestLogging() error {
|
||||
if !envconfig.DebugLogRequests() {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
|
||||
}
|
||||
|
||||
s.requestLogger = requestLogger
|
||||
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
|
||||
if s.requestLogger == nil {
|
||||
return handlers
|
||||
}
|
||||
|
||||
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
method := c.Request.Method
|
||||
host := c.Request.Host
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
l.log(route, method, scheme, host, contentType, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
|
||||
if l == nil || l.dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
if host == "" || scheme == "" {
|
||||
base := envconfig.Host()
|
||||
if host == "" {
|
||||
host = base.Host
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = base.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
routeForFilename := sanitizeRouteForFilename(route)
|
||||
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
|
||||
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
|
||||
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
|
||||
bodyPath := filepath.Join(l.dir, bodyFilename)
|
||||
curlPath := filepath.Join(l.dir, curlFilename)
|
||||
|
||||
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request body", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
|
||||
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
|
||||
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
|
||||
}
|
||||
|
||||
func sanitizeRouteForFilename(route string) string {
|
||||
route = strings.TrimPrefix(route, "/")
|
||||
if route == "" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(route))
|
||||
for _, r := range route {
|
||||
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
|
||||
b.WriteRune(r)
|
||||
} else {
|
||||
b.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -82,9 +81,7 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
aliasesOnce sync.Once
|
||||
aliases *store
|
||||
aliasesErr error
|
||||
requestLogger *inferenceRequestLogger
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -195,16 +192,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
// We cannot currently consolidate this into GetModel because all we'll
|
||||
// induce infinite recursion given the current code structure.
|
||||
name, err = getExistingName(name)
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -1591,30 +1581,27 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||
r.POST("/api/copy", s.CopyHandler)
|
||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
|
||||
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)...)
|
||||
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)...)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)...)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1664,6 +1651,9 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
if err := s.initRequestLogging(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
@@ -1964,20 +1954,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
name, err = getExistingName(name)
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(name.String())
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type aliasListResponse struct {
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type aliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
func (s *Server) ListAliasesHandler(c *gin.Context) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var aliases []aliasEntry
|
||||
if store != nil {
|
||||
aliases = store.List()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
|
||||
}
|
||||
|
||||
func (s *Server) CreateAliasHandler(c *gin.Context) {
|
||||
var req aliasEntry
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
req.Target = strings.TrimSpace(req.Target)
|
||||
if req.Alias == "" || req.Target == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Target must always be a valid model name
|
||||
targetName := model.ParseName(req.Target)
|
||||
if !targetName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
|
||||
return
|
||||
}
|
||||
|
||||
var aliasName model.Name
|
||||
if req.PrefixMatching {
|
||||
// For prefix aliases, we still parse the alias to normalize it,
|
||||
// but we allow any non-empty string since prefix patterns may not be valid model names
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
// Even if not valid as a model name, we accept it for prefix matching
|
||||
} else {
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := localModelExists(aliasName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, errAliasCycle) {
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp := aliasEntry{
|
||||
Alias: displayAliasName(aliasName),
|
||||
Target: displayAliasName(targetName),
|
||||
PrefixMatching: req.PrefixMatching,
|
||||
}
|
||||
if req.PrefixMatching && !aliasName.IsValid() {
|
||||
// For prefix aliases that aren't valid model names, use the raw alias
|
||||
resp.Alias = req.Alias
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) DeleteAliasHandler(c *gin.Context) {
|
||||
var req aliasDeleteRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
if req.Alias == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
|
||||
return
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
aliasName := model.ParseName(req.Alias)
|
||||
var deleted bool
|
||||
if aliasName.IsValid() {
|
||||
deleted, err = store.Delete(aliasName)
|
||||
} else {
|
||||
// For invalid model names (like prefix aliases), try deleting by raw string
|
||||
deleted, err = store.DeleteByString(req.Alias)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !deleted {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
@@ -1,426 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestAliasShadowingRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "shadowed-model",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasResolvesForChatRemote(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
var remoteModel string
|
||||
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req api.ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remoteModel = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
|
||||
p, err := url.Parse(rs.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "target-model",
|
||||
RemoteHost: rs.URL,
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "alias-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "hi"}},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Model != "alias-model" {
|
||||
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
|
||||
}
|
||||
|
||||
if remoteModel != "test" {
|
||||
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasBasicMatching(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a prefix alias: "myprefix-" -> "targetmodel"
|
||||
targetName := model.ParseName("targetmodel")
|
||||
|
||||
// Set a prefix alias (using "myprefix-" as the pattern)
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = append(store.prefixEntries, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "targetmodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that "myprefix-foo" resolves to "targetmodel"
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != targetName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test that "otherprefix-foo" does not resolve
|
||||
otherName := model.ParseName("otherprefix-foo")
|
||||
_, wasResolved, err = store.ResolveName(otherName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatal("expected name not to be resolved")
|
||||
}
|
||||
|
||||
// Test that exact alias takes precedence
|
||||
exactAlias := model.ParseName("myprefix-exact")
|
||||
exactTarget := model.ParseName("exacttarget")
|
||||
if err := store.Set(exactAlias, exactTarget, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resolved, wasResolved, err = store.ResolveName(exactAlias)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLongestMatchWins(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add two prefix aliases with overlapping patterns
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
|
||||
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
|
||||
}
|
||||
store.sortPrefixEntriesLocked()
|
||||
store.mu.Unlock()
|
||||
|
||||
// "abc-def-ghi" should match the longer prefix "abc-def-"
|
||||
testName := model.ParseName("abc-def-ghi")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedLongTarget := model.ParseName("long-target")
|
||||
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// "abc-xyz" should match the shorter prefix "abc-"
|
||||
testName2 := model.ParseName("abc-xyz")
|
||||
resolved, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedShortTarget := model.ParseName("short-target")
|
||||
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a chain: prefix "test-" -> "intermediate" -> "final"
|
||||
intermediate := model.ParseName("intermediate")
|
||||
final := model.ParseName("final")
|
||||
|
||||
// Add prefix alias
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Add exact alias for the intermediate step
|
||||
if err := store.Set(intermediate, final, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// "test-foo" should resolve through the chain to "final"
|
||||
testName := model.ParseName("test-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != final.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCRUD(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a prefix alias via API
|
||||
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "llama2",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var createResp aliasEntry
|
||||
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !createResp.PrefixMatching {
|
||||
t.Fatal("expected prefix_matching to be true in response")
|
||||
}
|
||||
|
||||
// List aliases and verify the prefix alias is included
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var listResp aliasListResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching && a.Target == "llama2" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected to find prefix alias in list")
|
||||
}
|
||||
|
||||
// Delete the prefix alias
|
||||
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify it's deleted
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching {
|
||||
t.Fatal("expected prefix alias to be deleted")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a prefix alias with mixed case
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that matching is case-insensitive
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (case-insensitive)")
|
||||
}
|
||||
expectedTarget := model.ParseName("targetmodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test uppercase request
|
||||
testName2 := model.ParseName("MYPREFIX-BAR")
|
||||
_, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (uppercase)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a local model that would match a prefix alias
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "myprefix-localmodel",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Create a prefix alias that would match the local model name
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "someothermodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
localModelName := model.ParseName("myprefix-localmodel")
|
||||
resolved, wasResolved, err := store.ResolveName(localModelName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
|
||||
}
|
||||
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
|
||||
nonLocalName := model.ParseName("myprefix-nonexistent")
|
||||
resolved, wasResolved, err = store.ResolveName(nonLocalName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected non-local model to resolve via prefix alias")
|
||||
}
|
||||
expectedTarget := model.ParseName("someothermodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
128
server/routes_request_log_test.go
Normal file
128
server/routes_request_log_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logDir := t.TempDir()
|
||||
requestLogger := &inferenceRequestLogger{dir: logDir}
|
||||
|
||||
const route = "/v1/chat/completions"
|
||||
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
|
||||
|
||||
var bodySeenByHandler string
|
||||
|
||||
r := gin.New()
|
||||
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body in handler: %v", err)
|
||||
}
|
||||
|
||||
bodySeenByHandler = string(body)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
|
||||
req.Host = "127.0.0.1:11434"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if bodySeenByHandler != requestBody {
|
||||
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
|
||||
}
|
||||
|
||||
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob body logs: %v", err)
|
||||
}
|
||||
if len(bodyFiles) != 1 {
|
||||
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
|
||||
}
|
||||
|
||||
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob curl logs: %v", err)
|
||||
}
|
||||
if len(curlFiles) != 1 {
|
||||
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
|
||||
}
|
||||
|
||||
bodyData, err := os.ReadFile(bodyFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body log: %v", err)
|
||||
}
|
||||
if string(bodyData) != requestBody {
|
||||
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
|
||||
}
|
||||
|
||||
curlData, err := os.ReadFile(curlFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read curl log: %v", err)
|
||||
}
|
||||
|
||||
curlString := string(curlData)
|
||||
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
|
||||
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
|
||||
}
|
||||
|
||||
bodyFileName := filepath.Base(bodyFiles[0])
|
||||
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
|
||||
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error creating request logger: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(requestLogger.dir)
|
||||
})
|
||||
|
||||
if requestLogger == nil || requestLogger.dir == "" {
|
||||
t.Fatalf("expected request logger directory to be set")
|
||||
}
|
||||
|
||||
info, err := os.Stat(requestLogger.dir)
|
||||
if err != nil {
|
||||
t.Fatalf("expected directory to exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatalf("expected %q to be a directory", requestLogger.dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRouteForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
route string
|
||||
want string
|
||||
}{
|
||||
{route: "/api/generate", want: "api_generate"},
|
||||
{route: "/v1/chat/completions", want: "v1_chat_completions"},
|
||||
{route: "/v1/messages", want: "v1_messages"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
|
||||
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
77
x/kvcache/cache.go
Normal file
77
x/kvcache/cache.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||
ErrNotSupported = errors.New("model does not support operation")
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
// ** used by model implementations **
|
||||
|
||||
// SetLayer sets the active layer of the cache
|
||||
SetLayer(layer int)
|
||||
|
||||
// Get returns the history of key and value tensors plus a mask
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||
|
||||
// Put stores a batch of key and value in the cache
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output of the cache to work better with specific kernels. If not called,
|
||||
// the backend settings will be used. This works well when calling Attention.
|
||||
//
|
||||
// The config can be overridden by models, especially if they require vanilla
|
||||
// output when implementing their own version of attention. To do this, pass
|
||||
// an empty ml.CacheConfig.
|
||||
//
|
||||
// Most models will not need to use this.
|
||||
SetConfig(ml.CacheConfig)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters.
|
||||
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||
// dtype: The data type for storing cache entries
|
||||
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||
// capacity: The number of cache entries to store, per sequence
|
||||
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs. reserve is to preallocate memory
|
||||
// without actually storing data in the cache.
|
||||
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
// If an error occurs, the entire context for the sequence should be
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
144
x/kvcache/causal.go
Normal file
144
x/kvcache/causal.go
Normal file
@@ -0,0 +1,144 @@
|
||||
//go:build mlx
|
||||
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
// Causal cache stores K and V tensors according to their position in the
|
||||
// sequence. Returns the history and a mask for attending to past tokens
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocPut ml.Tensor
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocGet ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
capacity int
|
||||
|
||||
offset int
|
||||
|
||||
backend ml.Backend
|
||||
ctxs map[int]ml.Context
|
||||
keys, values map[int]ml.Tensor
|
||||
|
||||
// TODO is this needed per layer, or will it always be consistent?
|
||||
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
}
|
||||
|
||||
func NewCausalCache() *Causal {
|
||||
return &Causal{
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
kHeadDims: make(map[int]int),
|
||||
vHeadDims: make(map[int]int),
|
||||
numKVHeads: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.DType = dtype
|
||||
c.capacity = capacity
|
||||
c.backend = backend
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
for _, ctx := range c.ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
locsPut := make([]int32, len(batch.Positions))
|
||||
for i := c.offset; i < len(batch.Positions); i++ {
|
||||
locsPut[i-c.offset] = int32(i)
|
||||
}
|
||||
c.offset += len(batch.Positions)
|
||||
locsGet := make([]int32, c.offset)
|
||||
for i := range c.offset {
|
||||
locsGet[i] = int32(i)
|
||||
}
|
||||
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(3)
|
||||
vHeadDim := value.Dim(3)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
kCellSize := kHeadDim * numKVHeads
|
||||
vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||
|
||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||
c.kHeadDims[c.curLayer] = kHeadDim
|
||||
c.vHeadDims[c.curLayer] = vHeadDim
|
||||
c.numKVHeads[c.curLayer] = numKVHeads
|
||||
}
|
||||
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||
|
||||
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
|
||||
// slog.Info("XXX Causal.Put ", "key", key)
|
||||
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||
|
||||
}
|
||||
|
||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := c.kHeadDims[c.curLayer]
|
||||
vHeadDim := c.vHeadDims[c.curLayer]
|
||||
numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// rowSize := numKVHeads * c.curBatchSize
|
||||
// cachedSize := c.curMask.Dim(1)
|
||||
cachedSize := c.curLocGet.Dim(0)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||
|
||||
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
156
x/kvcache/encoder.go
Normal file
156
x/kvcache/encoder.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Encoder cache stores K and V tensors that are position independent
|
||||
// //
|
||||
// // The tensors can be of any shape and will be returned as they were stored
|
||||
// // The mask is currently always nil
|
||||
// //
|
||||
// // Not currently safe for multiple sequences
|
||||
// type EncoderCache struct {
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // if something is stored during this pass, this
|
||||
// // will be the position (but there is no guarantee
|
||||
// // anything will be stored)
|
||||
// curPos int32
|
||||
|
||||
// // curReserve indicates that this forward pass is only for
|
||||
// // memory reservation and we should not update our metadata
|
||||
// // based on it.
|
||||
// curReserve bool
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // was something stored in the cache?
|
||||
// encoderCached bool
|
||||
|
||||
// // position of the cached data
|
||||
// encoderPos int32
|
||||
|
||||
// // ** cache data storage **
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
// }
|
||||
|
||||
// func NewEncoderCache() *EncoderCache {
|
||||
// return &EncoderCache{
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if maxSequences > 1 {
|
||||
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
// }
|
||||
|
||||
// c.backend = backend
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Close() {
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// // We work with the most recent image
|
||||
// if len(batch.Multimodal) > 0 {
|
||||
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
// }
|
||||
|
||||
// c.curReserve = reserve
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) EncoderCached() bool {
|
||||
// return c.encoderCached
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// if !c.curReserve {
|
||||
// c.encoderPos = c.curPos
|
||||
// c.encoderCached = true
|
||||
// }
|
||||
|
||||
// if c.config.PermutedV {
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3)
|
||||
// }
|
||||
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
||||
// }
|
||||
|
||||
// ctx.Forward(
|
||||
// key.Copy(ctx, c.keys[c.curLayer]),
|
||||
// value.Copy(ctx, c.values[c.curLayer]),
|
||||
// )
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// panic("encoder cache does not support multiple sequences")
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
// c.encoderCached = false
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
110
x/kvcache/wrapper.go
Normal file
110
x/kvcache/wrapper.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "math"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Wrapper cache is a container for multiple types of caches,
|
||||
// // such as for the encoding and decoding portions of a model.
|
||||
// type WrapperCache struct {
|
||||
// // caches we are wrapping
|
||||
// caches []Cache
|
||||
|
||||
// // cache to be used for this layer
|
||||
// curType int
|
||||
// }
|
||||
|
||||
// func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||
// return &WrapperCache{
|
||||
// caches: caches,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.SetConfig(config)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Close() {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// for i, cache := range c.caches {
|
||||
// err := cache.StartForward(ctx, batch, reserve)
|
||||
// if err != nil {
|
||||
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
// for j := i - 1; j >= 0; j-- {
|
||||
// for k := range batch.Positions {
|
||||
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||
// }
|
||||
// }
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.curType = 0
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetLayer(layer int) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.SetLayer(layer)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetLayerType(layerType int) {
|
||||
// c.curType = layerType
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) UnderlyingCache() Cache {
|
||||
// return c.caches[c.curType]
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.caches[c.curType].Get(ctx)
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// c.caches[c.curType].Put(ctx, key, value)
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.CopyPrefix(srcSeq, dstSeq, len)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||
// for _, cache := range c.caches {
|
||||
// if !cache.CanResume(seq, pos) {
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
// for _, cache := range c.caches {
|
||||
// err := cache.Remove(seq, beginIndex, endIndex)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
433
x/ml/backend.go
Normal file
433
x/ml/backend.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type Backend interface {
|
||||
// Close frees all memory associated with this backend
|
||||
// Close()
|
||||
|
||||
// Load(ctx context.Context, progress func(float32)) error
|
||||
|
||||
// BackendMemory returns the memory allocations that were made for this model
|
||||
// BackendMemory() BackendMemory
|
||||
|
||||
Config() fs.Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
// NewContextSize(size int) Context
|
||||
|
||||
// Enumerate the devices available for inference via this backend
|
||||
// BackendDevices() []DeviceInfo
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
// from the cache to meet specific requirements. It is frequently implemented in
|
||||
// conjunction with ScaledDotProductAttention.
|
||||
type BackendCacheConfig interface {
|
||||
CacheConfig() CacheConfig
|
||||
}
|
||||
|
||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output the cache to work better with specific kernels.
|
||||
type CacheConfig struct {
|
||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||
// cache itself will also be increased to a multiple of this size if needed.
|
||||
CachePadding int
|
||||
|
||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||
// and return the permuted version via Get. This uses the cache copy operation
|
||||
// to avoid a Contiguous call on the permuted tensor.
|
||||
PermutedV bool
|
||||
|
||||
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||
// default to DTypeF32.
|
||||
MaskDType DType
|
||||
|
||||
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||
MaskBatchPadding int
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// AllocMemory causes the backend to allocate memory for the model. If
|
||||
// false, this is only being used for discovering the required amount of
|
||||
// memory and cannot load the model for running.
|
||||
AllocMemory bool
|
||||
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
NumThreads int
|
||||
|
||||
// GPULayers is the set of layers to offload to GPUs
|
||||
GPULayers GPULayersList
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
|
||||
if _, ok := backends[name]; ok {
|
||||
panic("backend: backend already registered")
|
||||
}
|
||||
|
||||
backends[name] = f
|
||||
}
|
||||
|
||||
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
|
||||
be := os.Getenv("OLLAMA_BACKEND")
|
||||
if be == "" {
|
||||
be = "mlx"
|
||||
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
|
||||
}
|
||||
slog.Info("Loading new engine", "backend", be)
|
||||
if backend, ok := backends[be]; ok {
|
||||
return backend(modelPath, params)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported backend")
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
|
||||
FromFloats(s []float32, shape ...int) Tensor
|
||||
FromInts(s []int32, shape ...int) Tensor
|
||||
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
|
||||
|
||||
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
|
||||
Arange(start, stop, step float32, dtype DType) Tensor
|
||||
|
||||
Forward(...Tensor) Context
|
||||
|
||||
// SetBatchSize provides a hint on the batch size to optimize processing
|
||||
// Uses heuristics if not set
|
||||
// SetBatchSize(int)
|
||||
|
||||
Compute(...Tensor)
|
||||
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
|
||||
|
||||
// Reserve is analogous to Compute but rather than executing a
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
// worst case graph to ensure all resources are available for
|
||||
// for future inference.
|
||||
// Reserve()
|
||||
|
||||
// MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
// Input returns a context appropriate for creating tensors that are
|
||||
// inputs to the model (which includes things like output locations)
|
||||
Input() Context
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
|
||||
// Load a tensor from "filename" safetensors file, and compare with the input tensor
|
||||
// Returns error if the shape is inconsistent, or similarity measures are below 99%
|
||||
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
|
||||
}
|
||||
|
||||
type RoPEOptions struct {
|
||||
Base *float32
|
||||
Freqs Tensor
|
||||
}
|
||||
|
||||
func WithRoPEBase(base float32) func(*RoPEOptions) {
|
||||
return func(opts *RoPEOptions) {
|
||||
opts.Base = &base
|
||||
}
|
||||
}
|
||||
|
||||
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
|
||||
return func(opts *RoPEOptions) {
|
||||
opts.Freqs = freqs
|
||||
}
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
ToString() string
|
||||
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
|
||||
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
|
||||
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
|
||||
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
|
||||
|
||||
Dim(n int) int
|
||||
Stride(n int) int
|
||||
|
||||
Shape() []int
|
||||
DType() DType
|
||||
// Cast(ctx Context, dtype DType) Tensor
|
||||
|
||||
// Bytes() []byte
|
||||
Floats() []float32
|
||||
Ints() []int32
|
||||
|
||||
// FromBytes([]byte)
|
||||
// FromFloats([]float32)
|
||||
// FromInts([]int32)
|
||||
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
// Mul(ctx Context, t2 Tensor) Tensor
|
||||
// Div(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
Max(ctx Context, axes []int, keepDims bool) Tensor
|
||||
Min(ctx Context, axes []int, keepDims bool) Tensor
|
||||
|
||||
Matmul(ctx Context, a2 Tensor) Tensor
|
||||
// Mulmat(ctx Context, t2 Tensor) Tensor
|
||||
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||
// MulmatID(ctx Context, t2, ids Tensor) Tensor
|
||||
// AddID(ctx Context, t2, ids Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
L2Norm(ctx Context, eps float32) Tensor
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
// SumRows(ctx Context) Tensor
|
||||
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
|
||||
|
||||
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
// Sin(ctx Context) Tensor
|
||||
// Cos(ctx Context) Tensor
|
||||
// Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
// QuickGELU(ctx Context, up ...Tensor) Tensor
|
||||
// SILU(ctx Context, up ...Tensor) Tensor
|
||||
// RELU(ctx Context, up ...Tensor) Tensor
|
||||
// Sigmoid(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
|
||||
Transpose(ctx Context, shape ...int) Tensor
|
||||
Contiguous(ctx Context, allowColMajor bool) Tensor
|
||||
|
||||
// Pad(ctx Context, shape ...int) Tensor
|
||||
|
||||
// Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
|
||||
// Repeat repeats the tensor n times along dimension dim
|
||||
// Repeat(ctx Context, dim, n int) Tensor
|
||||
// Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||
// Rows(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
// TODO these probably aren't actually needed - false starts on trying to wire up cache
|
||||
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
|
||||
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
|
||||
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
|
||||
|
||||
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
|
||||
|
||||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
// Duplicate(ctx Context) Tensor
|
||||
|
||||
// Slice(ctx Context, dim, low, high, step int) Tensor
|
||||
// Chunk(ctx Context, dim int, size int) []Tensor
|
||||
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
|
||||
|
||||
// TopK(ctx Context, k int) Tensor
|
||||
// Argsort(ctx Context) Tensor
|
||||
// Mean(ctx Context) Tensor
|
||||
// Variance(ctx Context) Tensor
|
||||
// Stddev(ctx Context) Tensor
|
||||
// Sqr(ctx Context) Tensor
|
||||
// Sqrt(ctx Context) Tensor
|
||||
|
||||
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
//
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
//
|
||||
// kq = kq.Softmax(ctx)
|
||||
//
|
||||
// kqv := value.Mulmat(ctx, kq)
|
||||
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
// type ScaledDotProductAttention interface {
|
||||
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
|
||||
// }
|
||||
|
||||
// type number interface {
|
||||
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
// ~float32 | ~float64 |
|
||||
// ~complex64 | ~complex128
|
||||
// }
|
||||
|
||||
// func mul[T number](s ...T) T {
|
||||
// p := T(1)
|
||||
// for _, v := range s {
|
||||
// p *= v
|
||||
// }
|
||||
|
||||
// return p
|
||||
// }
|
||||
|
||||
// type DumpOptions func(*dumpOptions)
|
||||
|
||||
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
|
||||
// func DumpWithPrecision(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.Precision = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
|
||||
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
|
||||
// // beginning and end of each dimension will be printed.
|
||||
// func DumpWithThreshold(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.Threshold = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
|
||||
// func DumpWithEdgeItems(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.EdgeItems = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// type dumpOptions struct {
|
||||
// Precision, Threshold, EdgeItems int
|
||||
// }
|
||||
|
||||
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
|
||||
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
|
||||
// for _, optsFunc := range optsFuncs {
|
||||
// optsFunc(&opts)
|
||||
// }
|
||||
|
||||
// if mul(t.Shape()...) <= opts.Threshold {
|
||||
// opts.EdgeItems = math.MaxInt
|
||||
// }
|
||||
|
||||
// switch t.DType() {
|
||||
// case DTypeFloat32:
|
||||
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
|
||||
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
// })
|
||||
// case DTypeFloat16: // TODO other types...
|
||||
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
|
||||
// f32 = t.Copy(ctx, f32)
|
||||
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
|
||||
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
// })
|
||||
// case DTypeInt32:
|
||||
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
|
||||
// return strconv.FormatInt(int64(i), 10)
|
||||
// })
|
||||
// default:
|
||||
// return "<unsupported>"
|
||||
// }
|
||||
// }
|
||||
|
||||
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
// if t.Bytes() == nil {
|
||||
// ctx.Compute(t)
|
||||
// }
|
||||
|
||||
// s := make(S, mul(t.Shape()...))
|
||||
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// shape := t.Shape()
|
||||
// slices.Reverse(shape)
|
||||
|
||||
// var sb strings.Builder
|
||||
// var f func([]int, int)
|
||||
// f = func(dims []int, stride int) {
|
||||
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||
// sb.WriteString("[")
|
||||
// defer func() { sb.WriteString("]") }()
|
||||
// for i := 0; i < dims[0]; i++ {
|
||||
// if i >= items && i < dims[0]-items {
|
||||
// sb.WriteString("..., ")
|
||||
// // skip to next printable element
|
||||
// skip := dims[0] - 2*items
|
||||
// if len(dims) > 1 {
|
||||
// stride += mul(append(dims[1:], skip)...)
|
||||
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
||||
// }
|
||||
// i += skip - 1
|
||||
// } else if len(dims) > 1 {
|
||||
// f(dims[1:], stride)
|
||||
// stride += mul(dims[1:]...)
|
||||
// if i < dims[0]-1 {
|
||||
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||
// }
|
||||
// } else {
|
||||
// text := fn(s[stride+i])
|
||||
// if len(text) > 0 && text[0] != '-' {
|
||||
// sb.WriteString(" ")
|
||||
// }
|
||||
|
||||
// sb.WriteString(text)
|
||||
// if i < dims[0]-1 {
|
||||
// sb.WriteString(", ")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// f(shape, 0)
|
||||
|
||||
// return sb.String()
|
||||
// }
|
||||
|
||||
type DType int
|
||||
|
||||
const (
|
||||
DTypeBool DType = iota
|
||||
DTypeUint8
|
||||
DTypeUint16
|
||||
DTypeUint32
|
||||
DTypeUint64
|
||||
DTypeInt8
|
||||
DTypeInt16
|
||||
DTypeInt32
|
||||
DTypeInt64
|
||||
DTypeFloat16
|
||||
DTypeFloat32
|
||||
DTypeFloat64
|
||||
DTypeBfloat16
|
||||
DTypeComplex64
|
||||
)
|
||||
|
||||
type SamplingMode int
|
||||
|
||||
const (
|
||||
SamplingModeNearest SamplingMode = iota
|
||||
SamplingModeBilinear
|
||||
)
|
||||
3
x/ml/backend/backend.go
Normal file
3
x/ml/backend/backend.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package backend
|
||||
|
||||
// _ "github.com/ollama/ollama/x/ml/backend/mlx"
|
||||
61
x/ml/backend/mlx/CMakeLists.txt
Normal file
61
x/ml/backend/mlx/CMakeLists.txt
Normal file
@@ -0,0 +1,61 @@
|
||||
include(FetchContent)
|
||||
|
||||
# Read MLX version from top-level file (shared with Dockerfile)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
set(MLX_BUILD_SAFETENSORS ON)
|
||||
|
||||
function(set_target_output_directory _target)
|
||||
if(TARGET ${_target})
|
||||
set_target_properties(${_target} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
|
||||
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
|
||||
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
|
||||
)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Check for Metal support (macOS only)
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
execute_process(
|
||||
COMMAND
|
||||
zsh "-c"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(NOT MLX_METAL_VERSION)
|
||||
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
endif()
|
||||
else()
|
||||
# On Linux, disable Metal backend
|
||||
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
endif()
|
||||
|
||||
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
|
||||
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
|
||||
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
|
||||
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
|
||||
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
|
||||
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
|
||||
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
elseif(MLX_CUDA_ARCHITECTURES)
|
||||
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG})
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
set_target_output_directory(mlx)
|
||||
set_target_output_directory(mlxc)
|
||||
1278
x/ml/backend/mlx/mlx.go
Normal file
1278
x/ml/backend/mlx/mlx.go
Normal file
File diff suppressed because it is too large
Load Diff
92
x/ml/backend/mlx/mlx_dynamic.c
Normal file
92
x/ml/backend/mlx/mlx_dynamic.c
Normal file
@@ -0,0 +1,92 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.dylib",
|
||||
"@loader_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"@executable_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"build/lib/ollama/libmlxc.dylib",
|
||||
"../build/lib/ollama/libmlxc.dylib",
|
||||
NULL
|
||||
};
|
||||
#else
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.so",
|
||||
"$ORIGIN/../build/lib/ollama/libmlxc.so",
|
||||
"build/lib/ollama/libmlxc.so",
|
||||
"../build/lib/ollama/libmlxc.so",
|
||||
NULL
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
// Try each possible library path
|
||||
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
|
||||
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
|
||||
if (mlx_handle != NULL) {
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", LIB_NAMES[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to load library
|
||||
const char* err = LIB_ERROR();
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. %s",
|
||||
err ? err : "Unknown error");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
26
x/ml/backend/mlx/mlx_dynamic.h
Normal file
26
x/ml/backend/mlx/mlx_dynamic.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
314
x/ml/backend/mlx/mlx_test.go
Normal file
314
x/ml/backend/mlx/mlx_test.go
Normal file
@@ -0,0 +1,314 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
_ "github.com/ollama/ollama/x/model/models/gemma3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
|
||||
func TestLoadModel(t *testing.T) {
|
||||
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||
b := &Backend{}
|
||||
err := b.LoadSafeTensors(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("load failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromInts(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
data := []int32{1, 2, 3, 4, 5, 6}
|
||||
a := c.FromInts(data, 2, 3)
|
||||
slog.Info("", "array", a)
|
||||
t.Log(a.ToString())
|
||||
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromFloats(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
data := []float32{1, 2, 3, 4, 5, 6}
|
||||
a := c.FromFloats(data, 2, 3)
|
||||
slog.Info("", "array", a)
|
||||
t.Log(a.ToString())
|
||||
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||
}
|
||||
res := a.Floats()
|
||||
if !reflect.DeepEqual(res, data) {
|
||||
t.Fatalf("incorrect results: %v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
|
||||
t3 := t1.Add(c, t2)
|
||||
c.Compute(t3, exp)
|
||||
t3f := t3.Floats()
|
||||
if !reflect.DeepEqual(t3f, exp.Floats()) {
|
||||
t.Fatalf("incorrect result: %v", t3f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReshapeTranspose(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
|
||||
c.Compute(t1)
|
||||
t1f := t1.Floats()
|
||||
exp := []float32{
|
||||
0, 4, 8,
|
||||
1, 5, 9,
|
||||
2, 6, 10,
|
||||
3, 7, 11,
|
||||
12, 16, 20,
|
||||
13, 17, 21,
|
||||
14, 18, 22,
|
||||
15, 19, 23,
|
||||
}
|
||||
if !reflect.DeepEqual(t1f, exp) {
|
||||
t.Fatalf("incorrect results: %v", t1f)
|
||||
}
|
||||
}
|
||||
|
||||
func prod(vals ...int) int {
|
||||
r := 1
|
||||
for _, v := range vals {
|
||||
r *= v
|
||||
}
|
||||
return r
|
||||
}
|
||||
func TestMatmul(t *testing.T) {
|
||||
// TODO create scenarios...
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
s1 := []int{1, 3, 2, 4}
|
||||
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
|
||||
s2 := []int{4, 2}
|
||||
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
|
||||
t3 := t1.Matmul(c, t2)
|
||||
exp := []float32{
|
||||
28, 34,
|
||||
76, 98,
|
||||
|
||||
124, 162,
|
||||
172, 226,
|
||||
|
||||
220, 290,
|
||||
268, 354,
|
||||
}
|
||||
c.Compute(t3)
|
||||
t3f := t3.Floats()
|
||||
if !reflect.DeepEqual(t3f, exp) {
|
||||
t.Fatalf("incorrect result: %v", t3f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRows(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
|
||||
outputs := c.Zeros(ml.DTypeInt32, 1)
|
||||
t2 := t1.TakeAxes(c, outputs, 1)
|
||||
c.Forward(t1, t2).Compute(t1, t2)
|
||||
t.Log(t1.ToString())
|
||||
t.Log(t2.ToString())
|
||||
f := t2.Floats()
|
||||
t.Logf("Result: %v", f)
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
// Validate the caching algorithm
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
batchSize := 3
|
||||
headDim := 4
|
||||
numKVHeads := 2
|
||||
// Make cache twice the size of one test batch
|
||||
cells := batchSize * 2
|
||||
cellSize := numKVHeads * headDim
|
||||
shape := []int{1, numKVHeads, batchSize, headDim}
|
||||
stop := float32(1)
|
||||
for _, x := range shape {
|
||||
stop *= float32(x)
|
||||
}
|
||||
// Create the cache
|
||||
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
|
||||
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
|
||||
|
||||
// Input tensor
|
||||
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
|
||||
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
|
||||
|
||||
// Reshape to copy into the cache
|
||||
/*
|
||||
From MLX python/src/indexing.cpp mlx_scatter_args_array
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
auto up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
up = reshape(up, up_shape);
|
||||
*/
|
||||
numRows := 3
|
||||
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
|
||||
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
|
||||
|
||||
// Simulate cells 1,3,5 are available
|
||||
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
|
||||
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
|
||||
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
|
||||
cache.Scatter(c, indicies, up, axis)
|
||||
|
||||
c.Forward(cache)
|
||||
// Cache should contain the data now
|
||||
t.Log("Cache after put\n" + cache.ToString())
|
||||
|
||||
// Retrieve cache content and verify it matches
|
||||
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
|
||||
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
|
||||
|
||||
t1f := t1.Floats()
|
||||
outf := out.Floats()
|
||||
if !reflect.DeepEqual(t1f, outf) {
|
||||
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma3(t *testing.T) {
|
||||
// Why is the sky blue
|
||||
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
|
||||
limit := 50
|
||||
|
||||
// TODO generalize this
|
||||
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||
|
||||
m, err := model.New(dir, ml.BackendParams{})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to load model: %s", err)
|
||||
}
|
||||
b := m.Backend()
|
||||
ctx := b.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
|
||||
Positions: make([]int32, len(inputs)),
|
||||
Sequences: make([]int, len(inputs)),
|
||||
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
|
||||
Offset: 0,
|
||||
}
|
||||
for i := range len(inputs) {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
offset := len(inputs)
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
numSlots := 1
|
||||
batchSize := 512
|
||||
numCtx := 4096
|
||||
|
||||
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
|
||||
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
|
||||
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
|
||||
|
||||
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed cache.StartForward: %s", err)
|
||||
}
|
||||
}
|
||||
opts := api.DefaultOptions()
|
||||
var grammar *sample.GrammarSampler
|
||||
sampler := sample.NewSampler(
|
||||
opts.Temperature,
|
||||
opts.TopK,
|
||||
opts.TopP,
|
||||
opts.MinP,
|
||||
opts.Seed,
|
||||
grammar,
|
||||
)
|
||||
|
||||
t.Log("Starting Forward pass loop")
|
||||
pendingResponses := []string{}
|
||||
for {
|
||||
out, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
t.Fatalf("failed forward pass: %s", err)
|
||||
}
|
||||
ctx.Forward(out)
|
||||
outputs := out.Floats()
|
||||
t.Logf("finished forward pass! length:%d", len(outputs))
|
||||
// sample a token
|
||||
logits := outputs
|
||||
token, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to sample token: %s", err)
|
||||
}
|
||||
t.Logf("Sampled token: %v", token)
|
||||
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
t.Log("hit EOS")
|
||||
break
|
||||
}
|
||||
piece, err := m.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode token: %s", err)
|
||||
}
|
||||
|
||||
pendingResponses = append(pendingResponses, piece)
|
||||
sequence := strings.Join(pendingResponses, "")
|
||||
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
|
||||
t.Logf("hit stop token: %v", stop)
|
||||
break
|
||||
}
|
||||
t.Logf("RESULTS: %s", sequence)
|
||||
batch = input.Batch{
|
||||
Inputs: ctx.FromInts([]int32{token}, 1, 1),
|
||||
Positions: make([]int32, 1),
|
||||
Sequences: make([]int, 1),
|
||||
Outputs: ctx.FromInts([]int32{0}, 1),
|
||||
Offset: offset,
|
||||
}
|
||||
offset++
|
||||
batch.Positions[0] = 0
|
||||
err = cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed cache.StartForward: %s", err)
|
||||
}
|
||||
if offset > limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
335
x/ml/backend/mlx/quant.go
Normal file
335
x/ml/backend/mlx/quant.go
Normal file
@@ -0,0 +1,335 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/ops.h"
|
||||
|
||||
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
|
||||
|
||||
void unpack_32_4(uint8_t* data, int8_t* dst) {
|
||||
memset(dst, 0, 16);
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
|
||||
if (j % 2 != 0) {
|
||||
x <<= 4;
|
||||
}
|
||||
dst[j / 2] += x;
|
||||
}
|
||||
// Last 16 weights are in the higher bits
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uint8_t x = (data[j + 2] >> 4);
|
||||
if (j % 2 != 0) {
|
||||
x <<= 4;
|
||||
}
|
||||
dst[8 + j / 2] += x;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q4_0 tensors.
|
||||
// Data layout is: |16 bit scale|32 x 4bit weights|.
|
||||
void extract_q4_0_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
scales[i] = *((float16_t*)data);
|
||||
biases[i] = -8 * scales[i];
|
||||
unpack_32_4(data, weights);
|
||||
weights += 16;
|
||||
data += bytes_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q4_1 tensors.
|
||||
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
|
||||
void extract_q4_1_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
scales[i] = *((float16_t*)data);
|
||||
biases[i] = *((float16_t*)(data) + 1);
|
||||
unpack_32_4(data, weights);
|
||||
weights += 16;
|
||||
data += bytes_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q8_0 tensors.
|
||||
// Data layout is: |16 bit scale|32 x 8bit weights|.
|
||||
void extract_q8_0_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t weights_per_block = 32;
|
||||
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
uint8_t* block_data = data + i * bytes_per_block;
|
||||
scales[i] = *((float16_t*)block_data);
|
||||
biases[i] = -128 * scales[i];
|
||||
for (int64_t j = 0; j < weights_per_block; ++j) {
|
||||
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
|
||||
// Original data is in int8_t, so we add a bias of -128 and invert the
|
||||
// first bit.
|
||||
x ^= 1 << 7;
|
||||
weights[i * weights_per_block + j] = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Drived from ggml-quants.c
|
||||
|
||||
#define QK_K 256
|
||||
|
||||
// 6-bit quantization
|
||||
// weight is represented as x = a * q
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 6.5625 bits per weight
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||
uint16_t d; // super-block scale
|
||||
} block_q6_K;
|
||||
|
||||
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
|
||||
const int64_t nb = k / QK_K;
|
||||
block_q6_K *x = (block_q6_K *)vx;
|
||||
float16_t* y = (float16_t *)vy;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float16_t d = 0.0;
|
||||
memcpy(&d, &x[i].d, sizeof(d));
|
||||
|
||||
const uint8_t * restrict ql = x[i].ql;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const int8_t * restrict sc = x[i].scales;
|
||||
|
||||
for (int n = 0; n < QK_K; n += 128) {
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
int is = l/16;
|
||||
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
||||
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
||||
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
||||
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
||||
y[l + 0] = d * sc[is + 0] * q1;
|
||||
y[l + 32] = d * sc[is + 2] * q2;
|
||||
y[l + 64] = d * sc[is + 4] * q3;
|
||||
y[l + 96] = d * sc[is + 6] * q4;
|
||||
}
|
||||
y += 128;
|
||||
ql += 64;
|
||||
qh += 32;
|
||||
sc += 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define K_SCALE_SIZE 12
|
||||
#define GGML_COMMON_AGGR_U
|
||||
#define GGML_COMMON_AGGR_S
|
||||
|
||||
// 4-bit quantization
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 4.5 bits per weight
|
||||
typedef struct {
|
||||
union {
|
||||
struct {
|
||||
uint16_t d; // super-block scale for quantized scales
|
||||
uint16_t dmin; // super-block scale for quantized mins
|
||||
} GGML_COMMON_AGGR_S;
|
||||
uint16_t dm;
|
||||
} GGML_COMMON_AGGR_U;
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
|
||||
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63; *m = q[j + 4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||
}
|
||||
}
|
||||
|
||||
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
|
||||
block_q4_K *x = (block_q4_K *)vx;
|
||||
float16_t* y = (float16_t *)vy;
|
||||
const int nb = k / QK_K;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * q = x[i].qs;
|
||||
float16_t d = 0.0;
|
||||
memcpy(&d, &x[i].d, sizeof(d));
|
||||
float16_t min = 0.0;
|
||||
memcpy(&min, &x[i].dmin, sizeof(d));
|
||||
|
||||
int is = 0;
|
||||
uint8_t sc, m;
|
||||
for (int j = 0; j < QK_K; j += 64) {
|
||||
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
|
||||
const float16_t d1 = d * sc; const float16_t m1 = min * m;
|
||||
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
|
||||
const float16_t d2 = d * sc; const float16_t m2 = min * m;
|
||||
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
||||
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
||||
q += 32; is += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||
shape := append([]C.int{}, final_shape...)
|
||||
var weights_per_byte C.int
|
||||
if dtype == 2 || dtype == 3 {
|
||||
weights_per_byte = 2
|
||||
} else if dtype == 8 {
|
||||
weights_per_byte = 1
|
||||
} else {
|
||||
return r, fmt.Errorf("unsupported tensor type %d", dtype)
|
||||
}
|
||||
|
||||
weights_per_block := C.int(32)
|
||||
if shape[len(shape)-1]%weights_per_block != 0 {
|
||||
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
|
||||
}
|
||||
|
||||
weights_shape := append([]C.int{}, shape...)
|
||||
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
|
||||
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
|
||||
for i := range weights_shape {
|
||||
w_nbytes *= weights_shape[i]
|
||||
}
|
||||
w_data := make([]byte, w_nbytes)
|
||||
cbytes := C.CBytes(w_data)
|
||||
defer C.free(cbytes)
|
||||
weights := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&weights_shape[0],
|
||||
C.int(len(weights_shape)),
|
||||
C.MLX_UINT32,
|
||||
)
|
||||
|
||||
// For scales and bias
|
||||
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
|
||||
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
|
||||
for i := range shape {
|
||||
sb_nbytes *= shape[i]
|
||||
}
|
||||
|
||||
s_data := make([]byte, sb_nbytes)
|
||||
cbytes = C.CBytes(s_data)
|
||||
defer C.free(cbytes)
|
||||
scales := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
b_data := make([]byte, sb_nbytes)
|
||||
cbytes = C.CBytes(b_data)
|
||||
defer C.free(cbytes)
|
||||
biases := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
var bits C.int
|
||||
switch dtype {
|
||||
case 2:
|
||||
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 4
|
||||
case 3:
|
||||
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 4
|
||||
case 8:
|
||||
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 8
|
||||
}
|
||||
groupSize := C.mlx_optional_int{value: 32, has_value: true}
|
||||
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
|
||||
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
|
||||
C.mlx_dequantize(
|
||||
&r,
|
||||
weights,
|
||||
scales,
|
||||
biases,
|
||||
groupSize,
|
||||
bitsOpt,
|
||||
nil, // TODO mode
|
||||
dtypeOpt,
|
||||
stream,
|
||||
)
|
||||
C.mlx_array_free(weights)
|
||||
C.mlx_array_free(scales)
|
||||
C.mlx_array_free(biases)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||
size := 1
|
||||
for _, d := range shape {
|
||||
size *= int(d)
|
||||
}
|
||||
fdata := make([]float16.Float16, size)
|
||||
switch dtype {
|
||||
case 14:
|
||||
C.dequant_row_q6_K(
|
||||
data,
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
C.int(size),
|
||||
)
|
||||
|
||||
case 12:
|
||||
C.dequant_row_q4_K(
|
||||
data,
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
C.int(size),
|
||||
)
|
||||
default:
|
||||
return r, fmt.Errorf("unsupported K quant")
|
||||
}
|
||||
|
||||
r = C.mlx_array_new_data(
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
return r, nil
|
||||
}
|
||||
643
x/ml/device.go
Normal file
643
x/ml/device.go
Normal file
@@ -0,0 +1,643 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// GPULayers is a set of layers to be allocated on a single GPU
|
||||
type GPULayers struct {
|
||||
DeviceID
|
||||
|
||||
// Layers is a set of layer indicies to load
|
||||
Layers []int
|
||||
}
|
||||
|
||||
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
|
||||
func (g GPULayers) FirstLayer() int {
|
||||
if len(g.Layers) == 0 {
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
first := g.Layers[0]
|
||||
for i := 1; i < len(g.Layers); i++ {
|
||||
if g.Layers[i] < first {
|
||||
first = g.Layers[i]
|
||||
}
|
||||
}
|
||||
|
||||
return first
|
||||
}
|
||||
|
||||
func (g GPULayers) String() string {
|
||||
if len(g.Layers) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
slices.Sort(g.Layers)
|
||||
|
||||
contiguous := true
|
||||
base := g.Layers[0]
|
||||
for i := range g.Layers {
|
||||
if g.Layers[i] != base+i {
|
||||
contiguous = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if contiguous {
|
||||
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
|
||||
} else {
|
||||
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
|
||||
}
|
||||
}
|
||||
|
||||
// GPULayersList is a set of layer allocations across multiple GPUs
|
||||
type GPULayersList []GPULayers
|
||||
|
||||
func (l GPULayersList) Len() int { return len(l) }
|
||||
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
|
||||
|
||||
// Sort by the ordering of the layers offloaded
|
||||
func (l GPULayersList) Less(i, j int) bool {
|
||||
li := l[i].FirstLayer()
|
||||
lj := l[j].FirstLayer()
|
||||
|
||||
return li < lj
|
||||
}
|
||||
|
||||
func (l GPULayersList) String() string {
|
||||
if l.Sum() > 0 {
|
||||
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
|
||||
} else {
|
||||
return fmt.Sprintf("%v", []GPULayers(l))
|
||||
}
|
||||
}
|
||||
|
||||
// Sum is the total number of layers assigned across all GPUs
|
||||
func (l GPULayersList) Sum() int {
|
||||
var sum int
|
||||
|
||||
for _, g := range l {
|
||||
sum += len(g.Layers)
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
var h maphash.Hash
|
||||
|
||||
// Hash is an identifier of this layer assignment
|
||||
func (l GPULayersList) Hash() uint64 {
|
||||
h.Reset()
|
||||
for _, g := range l {
|
||||
if len(g.Layers) > 0 {
|
||||
h.WriteString(g.ID + g.Library)
|
||||
for _, l := range g.Layers {
|
||||
binary.Write(&h, binary.NativeEndian, int64(l))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// ErrNoMem is returned when panicing due to insufficient memory. It includes
|
||||
// the attempted memory allocation.
|
||||
type ErrNoMem struct {
|
||||
BackendMemory
|
||||
}
|
||||
|
||||
func (e ErrNoMem) Error() string {
|
||||
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
|
||||
}
|
||||
|
||||
// Minimal unique device identification
|
||||
type DeviceID struct {
|
||||
// ID is an identifier for the device for matching with system
|
||||
// management libraries. The ID is only unique for other devices
|
||||
// using the same Library.
|
||||
// This ID represents a "post filtered" view of the enumerated devices
|
||||
// if the ID is numeric
|
||||
ID string `json:"id"`
|
||||
|
||||
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
|
||||
Library string `json:"backend,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceMemory provides a breakdown of the memory needed
|
||||
// per device, such as a CPU or GPU.
|
||||
type DeviceMemory struct {
|
||||
DeviceID
|
||||
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string
|
||||
|
||||
// Weights is the per-layer memory needed for the model weights.
|
||||
Weights []uint64
|
||||
|
||||
// Cache is the per-layer memory needed for the KV cache.
|
||||
Cache []uint64
|
||||
|
||||
// Graph is the size of the compute graph. It is not per-layer.
|
||||
Graph uint64
|
||||
}
|
||||
|
||||
func sumMemory(mem []uint64) uint64 {
|
||||
var sum uint64
|
||||
|
||||
for _, m := range mem {
|
||||
sum += m
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
// Size returns the total size of the memory required by this device
|
||||
func (m DeviceMemory) Size() uint64 {
|
||||
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
|
||||
}
|
||||
|
||||
func memoryPresent(mem []uint64) bool {
|
||||
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
|
||||
}
|
||||
|
||||
func (m DeviceMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if memoryPresent(m.Weights) {
|
||||
attrs = append(attrs, slog.Any("Weights", m.Weights))
|
||||
}
|
||||
|
||||
if memoryPresent(m.Cache) {
|
||||
attrs = append(attrs, slog.Any("Cache", m.Cache))
|
||||
}
|
||||
|
||||
if m.Graph != 0 {
|
||||
attrs = append(attrs, slog.Any("Graph", m.Graph))
|
||||
}
|
||||
|
||||
if len(attrs) > 0 && m.ID != "" {
|
||||
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// BackendMemory provides the amount of memory required to load the model
|
||||
// per device based on the BackendParams. In some cases, not all required
|
||||
// allocations will be known at this point. However, the size of the most recent
|
||||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||
// accommodate that to make forward progress.
|
||||
type BackendMemory struct {
|
||||
// InputWeights are always located on the CPU and cannot be moved
|
||||
InputWeights uint64
|
||||
|
||||
// CPU model components are located in system memory. This does not
|
||||
// include unified memory allocated through the GPU.
|
||||
CPU DeviceMemory
|
||||
|
||||
// GPU model components are located on one or more GPUs.
|
||||
GPUs []DeviceMemory
|
||||
}
|
||||
|
||||
func (m BackendMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if m.InputWeights != 0 {
|
||||
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||
}
|
||||
|
||||
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
|
||||
for _, g := range m.GPUs {
|
||||
attrs = append(attrs, slog.Any(g.Name, g))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// Log prints a high level summary of the memory
|
||||
func (m BackendMemory) Log(level slog.Level) {
|
||||
var total uint64
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := sumMemory(m.CPU.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := gpu.Graph; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.CPU.Graph; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
|
||||
}
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
DeviceID
|
||||
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Description is the longer user-friendly identification of the device
|
||||
Description string `json:"description"`
|
||||
|
||||
// FilterID is populated with the unfiltered device ID if a numeric ID is used
|
||||
// so the device can be included.
|
||||
FilterID string `json:"filter_id,omitempty"`
|
||||
|
||||
// Integrated is set true for integrated GPUs, false for Discrete GPUs
|
||||
Integrated bool `json:"integration,omitempty"`
|
||||
|
||||
// PCIID is the bus, device and domain ID of the device for deduplication
|
||||
// when discovered by multiple backends
|
||||
PCIID string `json:"pci_id,omitempty"`
|
||||
|
||||
// TotalMemory is the total amount of memory the device can use for loading models
|
||||
TotalMemory uint64 `json:"total_memory"`
|
||||
|
||||
// FreeMemory is the amount of memory currently available on the device for loading models
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
|
||||
// ComputeMajor is the major version of capabilities of the device
|
||||
// if unsupported by the backend, -1 will be returned
|
||||
ComputeMajor int
|
||||
|
||||
// ComputeMinor is the minor version of capabilities of the device
|
||||
// if unsupported by the backend, -1 will be returned
|
||||
ComputeMinor int
|
||||
|
||||
// Driver Information
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
DriverMinor int `json:"driver_minor,omitempty"`
|
||||
|
||||
// Where backends were loaded from
|
||||
LibraryPath []string
|
||||
}
|
||||
|
||||
type SystemInfo struct {
|
||||
// ThreadCount is the optimal number of threads to use for inference
|
||||
ThreadCount int `json:"threads,omitempty"`
|
||||
|
||||
// TotalMemory is the total amount of system memory
|
||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||
|
||||
// FreeMemory is the amount of memory currently available on the system for loading models
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
|
||||
// FreeSwap is the amount of system swap space reported as available
|
||||
FreeSwap uint64 `json:"free_swap,omitempty"`
|
||||
}
|
||||
|
||||
func (d DeviceInfo) Compute() string {
|
||||
// AMD gfx is encoded into the major minor in hex form
|
||||
if strings.EqualFold(d.Library, "ROCm") {
|
||||
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
|
||||
}
|
||||
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
|
||||
}
|
||||
|
||||
func (d DeviceInfo) Driver() string {
|
||||
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
|
||||
}
|
||||
|
||||
// MinimumMemory reports the amount of memory that should be set aside
|
||||
// on the device for overhead (e.g. VRAM consumed by context structures independent
|
||||
// of model allocations)
|
||||
func (d DeviceInfo) MinimumMemory() uint64 {
|
||||
if d.Library == "Metal" {
|
||||
return 512 * format.MebiByte
|
||||
}
|
||||
return 457 * format.MebiByte
|
||||
}
|
||||
|
||||
// Sort by Free Space.
|
||||
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
|
||||
type ByFreeMemory []DeviceInfo
|
||||
|
||||
func (a ByFreeMemory) Len() int { return len(a) }
|
||||
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByFreeMemory) Less(i, j int) bool {
|
||||
if a[i].Integrated && !a[j].Integrated {
|
||||
return true
|
||||
} else if !a[i].Integrated && a[j].Integrated {
|
||||
return false
|
||||
}
|
||||
return a[i].FreeMemory < a[j].FreeMemory
|
||||
}
|
||||
|
||||
// ByPerformance groups devices by similar speed
|
||||
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
|
||||
resp := [][]DeviceInfo{}
|
||||
scores := []bool{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Integrated
|
||||
for i, score := range scores {
|
||||
if score == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
scores = append(scores, requested)
|
||||
resp = append(resp, []DeviceInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
|
||||
resp := [][]DeviceInfo{}
|
||||
libs := []string{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Library
|
||||
for i, lib := range libs {
|
||||
if lib == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
libs = append(libs, requested)
|
||||
resp = append(resp, []DeviceInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func LibraryPaths(l []DeviceInfo) []string {
|
||||
gpuLibs := []string{LibOllamaPath}
|
||||
for _, gpu := range l {
|
||||
for _, dir := range gpu.LibraryPath {
|
||||
needed := true
|
||||
for _, existing := range gpuLibs {
|
||||
if dir == existing {
|
||||
needed = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needed {
|
||||
gpuLibs = append(gpuLibs, dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
return gpuLibs
|
||||
}
|
||||
|
||||
type DeviceComparison int
|
||||
|
||||
const (
|
||||
UniqueDevice DeviceComparison = iota
|
||||
SameBackendDevice // The device is the same, and the library/backend is the same
|
||||
DuplicateDevice // The same physical device but different library/backend (overlapping device)
|
||||
)
|
||||
|
||||
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
|
||||
if a.PCIID != b.PCIID {
|
||||
return UniqueDevice
|
||||
}
|
||||
// If PCIID is empty, we have to use ID + library for uniqueness
|
||||
if a.PCIID == "" && a.DeviceID != b.DeviceID {
|
||||
return UniqueDevice
|
||||
}
|
||||
if a.Library == b.Library {
|
||||
return SameBackendDevice
|
||||
}
|
||||
return DuplicateDevice
|
||||
}
|
||||
|
||||
// For a SameBackendDevice, return true if b is better than a
|
||||
// e.g. newer GPU library version
|
||||
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
|
||||
aLib := a.LibraryPath[len(a.LibraryPath)-1]
|
||||
bLib := b.LibraryPath[len(b.LibraryPath)-1]
|
||||
if aLib == bLib {
|
||||
return false
|
||||
}
|
||||
aLibSplit := strings.SplitN(aLib, "_", 2)
|
||||
bLibSplit := strings.SplitN(bLib, "_", 2)
|
||||
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
|
||||
return false
|
||||
}
|
||||
if aLibSplit[0] != bLibSplit[0] {
|
||||
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
|
||||
return false
|
||||
}
|
||||
if aLibSplit[1] == bLibSplit[1] {
|
||||
return false
|
||||
}
|
||||
cmp := []string{aLibSplit[1], bLibSplit[1]}
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
|
||||
return cmp[0] == bLibSplit[1]
|
||||
}
|
||||
|
||||
// For each GPU, check if it does NOT support flash attention
|
||||
func FlashAttentionSupported(l []DeviceInfo) bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
|
||||
gpu.Library == "ROCm" ||
|
||||
gpu.Library == "Vulkan"
|
||||
|
||||
if !supportsFA {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variables
|
||||
// Set mustFilter true to enable filtering of CUDA devices
|
||||
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
env := map[string]string{}
|
||||
for _, d := range l {
|
||||
d.updateVisibleDevicesEnv(env, mustFilter)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// NeedsInitValidation returns true if the device in question has the potential
|
||||
// to crash at inference time and requires deeper validation before we include
|
||||
// it in the supported devices list.
|
||||
func (d DeviceInfo) NeedsInitValidation() bool {
|
||||
// ROCm: rocblas will crash on unsupported devices.
|
||||
// CUDA: verify CC is supported by the version of the library
|
||||
return d.Library == "ROCm" || d.Library == "CUDA"
|
||||
}
|
||||
|
||||
// Set the init validation environment variable
|
||||
func (d DeviceInfo) AddInitValidation(env map[string]string) {
|
||||
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
|
||||
}
|
||||
|
||||
// PreferredLibrary returns true if this library is preferred over the other input
|
||||
// library
|
||||
// Used to filter out Vulkan in favor of CUDA or ROCm
|
||||
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
|
||||
// TODO in the future if we find Vulkan is better than ROCm on some devices
|
||||
// that implementation can live here.
|
||||
|
||||
if d.Library == "CUDA" || d.Library == "ROCm" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
|
||||
var envVar string
|
||||
switch d.Library {
|
||||
case "ROCm":
|
||||
// ROCm must be filtered as it can crash the runner on unsupported devices
|
||||
envVar = "ROCR_VISIBLE_DEVICES"
|
||||
if runtime.GOOS != "linux" {
|
||||
envVar = "HIP_VISIBLE_DEVICES"
|
||||
}
|
||||
case "CUDA":
|
||||
if !mustFilter {
|
||||
// By default we try to avoid filtering CUDA devices because ROCm also
|
||||
// looks at the CUDA env var, and gets confused in mixed vendor environments.
|
||||
return
|
||||
}
|
||||
envVar = "CUDA_VISIBLE_DEVICES"
|
||||
default:
|
||||
// Vulkan is not filtered via env var, but via scheduling decisions
|
||||
return
|
||||
}
|
||||
v, existing := env[envVar]
|
||||
if existing {
|
||||
v = v + ","
|
||||
}
|
||||
if d.FilterID != "" {
|
||||
v = v + d.FilterID
|
||||
} else {
|
||||
v = v + d.ID
|
||||
}
|
||||
env[envVar] = v
|
||||
}
|
||||
|
||||
type BaseRunner interface {
|
||||
// GetPort returns the localhost port number the runner is running on
|
||||
GetPort() int
|
||||
|
||||
// HasExited indicates if the runner is no longer running. This can be used during
|
||||
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
|
||||
HasExited() bool
|
||||
}
|
||||
|
||||
type RunnerDiscovery interface {
|
||||
BaseRunner
|
||||
|
||||
// GetDeviceInfos will perform a query of the underlying device libraries
|
||||
// for device identification and free VRAM information
|
||||
// During bootstrap scenarios, this routine may take seconds to complete
|
||||
GetDeviceInfos(ctx context.Context) []DeviceInfo
|
||||
}
|
||||
|
||||
type FilteredRunnerDiscovery interface {
|
||||
RunnerDiscovery
|
||||
|
||||
// GetActiveDeviceIDs returns the filtered set of devices actively in
|
||||
// use by this runner for running models. If the runner is a bootstrap runner, no devices
|
||||
// will be active yet so no device IDs are returned.
|
||||
// This routine will not query the underlying device and will return immediately
|
||||
GetActiveDeviceIDs() []DeviceID
|
||||
}
|
||||
|
||||
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
|
||||
var moreDevices []DeviceInfo
|
||||
port := runner.GetPort()
|
||||
tick := time.Tick(10 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to finish discovery before timeout")
|
||||
case <-tick:
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
// slog.Warn("failed to send request", "error", err)
|
||||
if runner.HasExited() {
|
||||
return nil, fmt.Errorf("runner crashed")
|
||||
}
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// old runner, fall back to bootstrapping model
|
||||
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
slog.Warn("failed to read response", "error", err)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
|
||||
return nil, fmt.Errorf("runner error: %s", string(body))
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &moreDevices); err != nil {
|
||||
slog.Warn("unmarshal encode response", "error", err)
|
||||
continue
|
||||
}
|
||||
return moreDevices, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
103
x/ml/nn/attention.go
Normal file
103
x/ml/nn/attention.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
// Attention implements scaled dot-product attention for transformer models:
|
||||
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
ctx.Forward(query)
|
||||
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
ctx.Forward(key, value)
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
} else if cache == nil {
|
||||
panic("key & value tensors must be provided if cache is nil")
|
||||
}
|
||||
|
||||
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
|
||||
// panic("after cache get") //
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
|
||||
|
||||
// var mask ml.Tensor
|
||||
if cache != nil {
|
||||
key, value, _ = cache.Get(ctx)
|
||||
}
|
||||
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
|
||||
// panic("after cache get") //
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
|
||||
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
|
||||
if cache != nil {
|
||||
// TODO what to do with vmla?
|
||||
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
|
||||
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
|
||||
|
||||
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
|
||||
} else {
|
||||
panic("else case not supported")
|
||||
// TODO transpose shapes are wrong
|
||||
// key = key.Transpose(ctx, 0, 2, 1, 3)
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
|
||||
|
||||
// kq := query.Matmul(ctx, key)
|
||||
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
// kq = kq.Softmax(ctx)
|
||||
|
||||
// kqv := kq.Matmul(ctx, value)
|
||||
|
||||
// if vmla != nil {
|
||||
// kqv = kqv.Matmul(ctx, vmla)
|
||||
// }
|
||||
|
||||
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
|
||||
}
|
||||
}
|
||||
30
x/ml/nn/convolution.go
Normal file
30
x/ml/nn/convolution.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Conv2D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
|
||||
if m.Bias != nil {
|
||||
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
|
||||
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type Conv3D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
|
||||
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
return t
|
||||
}
|
||||
11
x/ml/nn/embedding.go
Normal file
11
x/ml/nn/embedding.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Embedding struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
return m.Weight.TakeAxes(ctx, hiddenState, 0)
|
||||
}
|
||||
32
x/ml/nn/linear.go
Normal file
32
x/ml/nn/linear.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Linear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type LinearBatch struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
|
||||
panic("not yet ported")
|
||||
// t = m.Weight.MulmatID(ctx, t, indices)
|
||||
// if m.Bias != nil {
|
||||
// t = t.AddID(ctx, m.Bias, indices)
|
||||
// }
|
||||
|
||||
// return t
|
||||
}
|
||||
29
x/ml/nn/normalization.go
Normal file
29
x/ml/nn/normalization.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
|
||||
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
|
||||
// slog.Info("RMSNorm", "eps", eps)
|
||||
// fmt.Fprintln(os.Stderr, t.ToString())
|
||||
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
|
||||
|
||||
// TODO this is probably model specific, not generalized...
|
||||
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
|
||||
|
||||
return t.RMSNorm(ctx, w, eps)
|
||||
}
|
||||
41
x/ml/nn/pooling/pooling.go
Normal file
41
x/ml/nn/pooling/pooling.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package pooling
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeNone Type = iota
|
||||
TypeMean
|
||||
TypeCLS
|
||||
TypeLast
|
||||
)
|
||||
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
return "Mean"
|
||||
case TypeCLS:
|
||||
return "CLS"
|
||||
case TypeLast:
|
||||
return "Last"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
switch t {
|
||||
// case TypeMean:
|
||||
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
|
||||
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
// case TypeCLS:
|
||||
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
|
||||
// case TypeLast:
|
||||
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
|
||||
default:
|
||||
panic("unknown pooling type")
|
||||
}
|
||||
}
|
||||
72
x/ml/nn/rope/rope.go
Normal file
72
x/ml/nn/rope/rope.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package rope
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
// Options contains optional parameters for RoPE function
|
||||
type Options struct {
|
||||
Type int
|
||||
Factors ml.Tensor
|
||||
|
||||
// YaRN options
|
||||
YaRN struct {
|
||||
OriginalContextLength int
|
||||
ExtrapolationFactor,
|
||||
AttentionFactor,
|
||||
BetaFast,
|
||||
BetaSlow float32
|
||||
}
|
||||
|
||||
// MRoPE options
|
||||
MRoPE struct {
|
||||
Sections []int
|
||||
}
|
||||
}
|
||||
|
||||
// WithTypeNeoX sets RoPE type to NeoX
|
||||
func WithTypeNeoX() func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type = 2
|
||||
}
|
||||
}
|
||||
|
||||
// WithFactors sets custom rope factors
|
||||
func WithFactors(factors ml.Tensor) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
if factors != nil {
|
||||
opts.Factors = factors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithOriginalContextLength sets a custom context length
|
||||
func WithOriginalContextLength(n int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.OriginalContextLength = n
|
||||
}
|
||||
}
|
||||
|
||||
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.ExtrapolationFactor = extrapolationFactor
|
||||
}
|
||||
}
|
||||
|
||||
func WithAttentionFactor(attentionFactor float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.AttentionFactor = attentionFactor
|
||||
}
|
||||
}
|
||||
|
||||
func WithMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1 << 3
|
||||
opts.MRoPE.Sections = sections
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1<<3 | 1<<5
|
||||
opts.MRoPE.Sections = sections
|
||||
}
|
||||
}
|
||||
56
x/ml/path.go
Normal file
56
x/ml/path.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// LibPath is a path to lookup dynamic libraries
|
||||
// in development it's usually 'build/lib/ollama'
|
||||
// in distribution builds it's 'lib/ollama' on Windows
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
var libPath string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
|
||||
case "linux":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
|
||||
case "darwin":
|
||||
libPath = filepath.Dir(exe)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
paths := []string{
|
||||
libPath,
|
||||
|
||||
// build paths for development
|
||||
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
|
||||
filepath.Join(cwd, "build", "lib", "ollama"),
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Dir(exe)
|
||||
}()
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenizer
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||
if len(pretokenizer) == 0 {
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizer {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
322
x/model/bytepairencoding_test.go
Normal file
322
x/model/bytepairencoding_test.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
types := make([]int32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = 1
|
||||
}
|
||||
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
tokens = append(tokens, token) //nolint:makezero
|
||||
types = append(types, 3) //nolint:makezero
|
||||
vocab[token] = int32(len(vocab))
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
merges := make([]string, 0, 50000)
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "#") {
|
||||
merges = append(merges, scanner.Text())
|
||||
}
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
}
|
||||
|
||||
func TestLlama(t *testing.T) {
|
||||
tokenizer := llama(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids, err := tokenizer.Encode("hello world", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
|
||||
s, err := tokenizer.Decode([]int32{15339, 1917})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s != "hello world" {
|
||||
t.Errorf("got %q, want hello world", s)
|
||||
}
|
||||
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple repeated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
strings.Repeat("0", 1): {15},
|
||||
strings.Repeat("0", 2): {410},
|
||||
strings.Repeat("0", 3): {931},
|
||||
strings.Repeat("0", 4): {931, 15},
|
||||
strings.Repeat("0", 5): {931, 410},
|
||||
strings.Repeat("0", 6): {931, 931},
|
||||
strings.Repeat("0", 7): {931, 931, 15},
|
||||
strings.Repeat("0", 8): {931, 931, 410},
|
||||
strings.Repeat("0", 9): {931, 931, 931},
|
||||
strings.Repeat("0", 10): {931, 931, 931, 15},
|
||||
strings.Repeat("0", 11): {931, 931, 931, 410},
|
||||
strings.Repeat("0", 12): {931, 931, 931, 931},
|
||||
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
|
||||
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
|
||||
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
|
||||
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
|
||||
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
|
||||
"Hello World": {"Hello", " ", " World"},
|
||||
"Hello\nWorld": {"Hello", "\n", "World"},
|
||||
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for b := 0x00; b <= 0xFF; b++ {
|
||||
input := string(rune(b))
|
||||
ids, err := tokenizer.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if b == 0x00 {
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != input {
|
||||
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
n := min(int(math.Pow10(i)), len(bts))
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
ids, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
patterns,
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
patterns: []string{
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||
},
|
||||
{
|
||||
name: "individual digits",
|
||||
patterns: []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
76
x/model/input/input.go
Normal file
76
x/model/input/input.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package input
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
// Multimodal is a multimodal embedding or a component of one.
|
||||
// For example, it could be a row of an image that can be processed
|
||||
// independently.
|
||||
type Multimodal struct {
|
||||
// Tensor is the embedding data. Implementations may chose what to
|
||||
// store here or it may be nil if not needed. However, any ml.Tensor
|
||||
// objects must be stored here and not in Data.
|
||||
Tensor ml.Tensor
|
||||
|
||||
// Data is implementation-specific opaque data, such as metadata on how
|
||||
// to layout Tensor. It may be nil if not needed. It may also store larger
|
||||
// objects such as complete images if they are to be processed later.
|
||||
Data any
|
||||
}
|
||||
|
||||
// Input represents one token in the input stream
|
||||
type Input struct {
|
||||
// Token is a single element of text.
|
||||
Token int32
|
||||
|
||||
// Multimodal is represents a non-text element such as an
|
||||
// image (or part of one if the image can be processed in pieces).
|
||||
// It may be used either together with Token or on its own.
|
||||
Multimodal []Multimodal
|
||||
|
||||
// MultimodalHash is a unique representation of the data
|
||||
// stored in Multimodal, used for caching and comparing
|
||||
// equality.
|
||||
MultimodalHash uint64
|
||||
|
||||
// SameBatch forces the following number of tokens to be processed
|
||||
// in a single batch, breaking and extending batches as needed.
|
||||
// Useful for things like images that must be processed in one
|
||||
// shot.
|
||||
SameBatch int
|
||||
}
|
||||
|
||||
// MultimodalIndex is a multimodal element (such as an image)
|
||||
// together with an index into the slice of Inputs with the
|
||||
// corresponding token. Note that the index is not the same
|
||||
// as the position - to find that use the index with the
|
||||
// Positions slice.
|
||||
type MultimodalIndex struct {
|
||||
Index int
|
||||
Multimodal []Multimodal
|
||||
}
|
||||
|
||||
// Batch contains the inputs for a model forward pass
|
||||
type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs ml.Tensor
|
||||
|
||||
// TODO maybe not the optimal way to handle this
|
||||
// Offset of final tensor in the final batch
|
||||
Offset int
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
Positions []int32
|
||||
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
}
|
||||
333
x/model/model.go
Normal file
333
x/model/model.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
_ "golang.org/x/image/bmp"
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
_ "github.com/ollama/ollama/x/ml/backend"
|
||||
"github.com/ollama/ollama/x/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
ErrUnsupportedModel = errors.New("model not supported")
|
||||
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||
)
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
// generates an output (typically an embedding) that can be used by the model.
|
||||
//
|
||||
// The return value is one or more tensors, each with optional model-specific
|
||||
// opaque metadata. Typically, the tensors might be views into an embedding
|
||||
// with each view representing a chunk of data that can be processed independently
|
||||
// in different batches.
|
||||
//
|
||||
// The result may be cached by the runner.
|
||||
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
|
||||
|
||||
// PostTokenize is called after tokenization to allow the model to edit the
|
||||
// input stream to correctly arrange multimodal elements.
|
||||
//
|
||||
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
|
||||
// in the order that the user provided them. Each element of the slice will be
|
||||
// either a single token or single multimodal object.
|
||||
//
|
||||
// The model must ensure that inputs are stored according to how they will be
|
||||
// processed and stored in the cache. For example, Llava-style models should insert
|
||||
// placeholder tokens equal to the feature size of the corresponding image with
|
||||
// the image itself attached to and split across these tokens. When Forward is called
|
||||
// a partial subset of these tokens may be submitted according to the batch size.
|
||||
//
|
||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
type Base struct {
|
||||
b ml.Backend
|
||||
config
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Cache kvcache.Cache
|
||||
}
|
||||
|
||||
// Backend returns the underlying backend that will run the model
|
||||
func (m *Base) Backend() ml.Backend {
|
||||
return m.b
|
||||
}
|
||||
|
||||
func (m *Base) Config() config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
var models = make(map[string]func(fs.Config) (Model, error))
|
||||
|
||||
// Register registers a model constructor for the given architecture
|
||||
func Register(name string, f func(fs.Config) (Model, error)) {
|
||||
if _, ok := models[name]; ok {
|
||||
panic("model: model already registered")
|
||||
}
|
||||
|
||||
models[name] = f
|
||||
}
|
||||
|
||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
b, err := ml.NewBackend(modelPath, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := modelForArch(b.Config())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
meta, err := fsggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := modelForArch(meta.KV())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
func modelForArch(c fs.Config) (Model, error) {
|
||||
arch := c.Architecture()
|
||||
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedModel
|
||||
}
|
||||
|
||||
return f(c)
|
||||
}
|
||||
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
|
||||
if t.Kind() == reflect.Struct {
|
||||
allNil := true
|
||||
for i := range t.NumField() {
|
||||
tt := t.Field(i).Type
|
||||
vv := v.Field(i)
|
||||
if !vv.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// make a copy
|
||||
tagsCopy := tags
|
||||
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
||||
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||
}
|
||||
|
||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||
vv.Set(reflect.ValueOf(base))
|
||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||
var fn func([]Tag, string, string) [][]string
|
||||
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||
if len(tags) > 0 {
|
||||
var names []string
|
||||
if tags[0].name != "" {
|
||||
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
|
||||
names = append(names, prefix+n+suffix)
|
||||
}
|
||||
}
|
||||
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
|
||||
if len(names) == 0 {
|
||||
// current tag has no name, use child names only
|
||||
fullNames = append(fullNames, childNames...)
|
||||
} else if len(childNames) == 0 {
|
||||
// current tag has names but no children, create branches for each name
|
||||
for _, name := range names {
|
||||
fullNames = append(fullNames, []string{name})
|
||||
}
|
||||
} else {
|
||||
// merge each name with each child
|
||||
for _, name := range names {
|
||||
for _, childName := range childNames {
|
||||
fullNames = append(fullNames, append([]string{name}, childName...))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fullNames
|
||||
}
|
||||
|
||||
names := fn(tagsCopy, "", "")
|
||||
for _, name := range names {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
logutil.Trace("found tensor", "", tensor)
|
||||
vv.Set(reflect.ValueOf(tensor))
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||
setPointer(base, vv, tagsCopy)
|
||||
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||
for i := range vv.Len() {
|
||||
vvv := vv.Index(i)
|
||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
|
||||
} else {
|
||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !canNil(tt) || !vv.IsNil() {
|
||||
allNil = false
|
||||
}
|
||||
}
|
||||
|
||||
if allNil {
|
||||
return reflect.Zero(t)
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||
vv := v
|
||||
if v.Kind() == reflect.Interface {
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
}
|
||||
|
||||
vv = reflect.Indirect(vv)
|
||||
if v.IsNil() {
|
||||
vv = reflect.New(v.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if f := populateFields(base, vv, tags...); f.CanAddr() {
|
||||
v.Set(f.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
name,
|
||||
// prefix and suffix are applied to child tags
|
||||
prefix,
|
||||
suffix string
|
||||
alternatives []string
|
||||
}
|
||||
|
||||
func parseTag(s string) (tag Tag) {
|
||||
parts := strings.Split(s, ",")
|
||||
if len(parts) > 0 {
|
||||
tag.name = parts[0]
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
|
||||
// elevate alternative to primary if no primary given
|
||||
tag.name = value
|
||||
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
|
||||
} else if ok {
|
||||
tag.alternatives = append(tag.alternatives, value)
|
||||
}
|
||||
if value, ok := strings.CutPrefix(part, "pre:"); ok {
|
||||
tag.prefix = value
|
||||
}
|
||||
if value, ok := strings.CutPrefix(part, "suf:"); ok {
|
||||
tag.suffix = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func canNil(t reflect.Type) bool {
|
||||
return t.Kind() == reflect.Chan ||
|
||||
t.Kind() == reflect.Func ||
|
||||
t.Kind() == reflect.Interface ||
|
||||
t.Kind() == reflect.Map ||
|
||||
t.Kind() == reflect.Pointer ||
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
||||
if len(batch.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
58
x/model/models/gemma3/embed.go
Normal file
58
x/model/models/gemma3/embed.go
Normal file
@@ -0,0 +1,58 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
"github.com/ollama/ollama/x/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
|
||||
Dense [2]*nn.Linear `gguf:"dense"`
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
157
x/model/models/gemma3/model.go
Normal file
157
x/model/models/gemma3/model.go
Normal file
@@ -0,0 +1,157 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
|
||||
*VisionModel `gguf:"vision_tower.vision_model"`
|
||||
*TextModel `gguf:"language_model.model"`
|
||||
|
||||
*MultiModalProjector `gguf:"multi_modal_projector"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type MultiModalProjector struct {
|
||||
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
|
||||
InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
|
||||
|
||||
tokensPerImage int
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
|
||||
l := visionOutputs.Dim(0)
|
||||
|
||||
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
patchesPerImage := imageSize / patchSize
|
||||
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
|
||||
|
||||
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
|
||||
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
|
||||
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
|
||||
|
||||
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
|
||||
visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// slog.Info("XXX Config", "c", c)
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{
|
||||
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
|
||||
},
|
||||
}
|
||||
|
||||
// slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
||||
|
||||
// TODO need to implement sliding window...
|
||||
m.Cache = kvcache.NewCausalCache()
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Token: 256000}, // <end_of_image>
|
||||
&input.Input{Token: 108}, // "\n\n"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3", New)
|
||||
model.Register("gemma3_embed", newEmbedModel)
|
||||
}
|
||||
211
x/model/models/gemma3/model_text.go
Normal file
211
x/model/models/gemma3/model_text.go
Normal file
@@ -0,0 +1,211 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
type TextConfig struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase, ropeGlobalBase float32
|
||||
largeModelScaling bool
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
|
||||
Layers []TextLayer `gguf:"layers"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"norm"`
|
||||
Output *nn.Linear `gguf:"embed_tokens"`
|
||||
|
||||
*TextConfig
|
||||
}
|
||||
|
||||
const (
|
||||
gemmaGlobalCacheCount = 6
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
// const (
|
||||
// cacheTypeSWA = iota
|
||||
// cacheTypeCausal
|
||||
// )
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
numBlocks := int(c.Uint("block_count"))
|
||||
|
||||
m := TextModel{
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextConfig: &TextConfig{
|
||||
hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
|
||||
numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
|
||||
ropeScale: 1, // 1 - default is 1, implied in python code
|
||||
// vocabSize: vocabSize, // 262144
|
||||
// attnValLen: int(c.Uint("attention.value_length", 256)), //256
|
||||
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||
// (8 instead of 1)
|
||||
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
},
|
||||
}
|
||||
if numBlocks == gemma27BLayerCount {
|
||||
m.largeModelScaling = true
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"q_proj"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
|
||||
Key *nn.Linear `gguf:"k_proj"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
|
||||
Value *nn.Linear `gguf:"v_proj"`
|
||||
Output *nn.Linear `gguf:"o_proj"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
B := hiddenState.Dim(0)
|
||||
L := hiddenState.Dim(1)
|
||||
ropeBase := opts.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = opts.ropeGlobalBase
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
|
||||
k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
|
||||
v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
traditional := false
|
||||
q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
|
||||
k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
|
||||
|
||||
// TODO - this is wrong somehow so commenting out
|
||||
// if opts.largeModelScaling {
|
||||
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
// } else {
|
||||
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
|
||||
// }
|
||||
|
||||
scaleFactor := math.Pow(256, -0.5)
|
||||
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// ropeBase := m.TextConfig.ropeLocalBase
|
||||
// if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
// ropeBase = m.TextConfig.ropeGlobalBase
|
||||
// }
|
||||
// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
|
||||
panic("not yet implemented")
|
||||
// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Up *nn.Linear `gguf:"up_proj"`
|
||||
Down *nn.Linear `gguf:"down_proj"`
|
||||
Gate *nn.Linear `gguf:"gate_proj"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
|
||||
SelfAttention *TextSelfAttention `gguf:"self_attn"`
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
|
||||
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
|
||||
MLP *TextMLP `gguf:"mlp"`
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
|
||||
}
|
||||
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
residual := hiddenState
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
|
||||
residual = residual.TakeAxes(ctx, outputs, 1)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
|
||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
// var except []int
|
||||
// for _, image := range batch.Multimodal {
|
||||
// visionOutputs := image.Multimodal[0].Tensor
|
||||
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
|
||||
// []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
|
||||
// []int{image.Index * hiddenState.Stride(1)}, 0)))
|
||||
|
||||
// for i := range visionOutputs.Dim(1) {
|
||||
// except = append(except, image.Index+i)
|
||||
// }
|
||||
// }
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
// gemma alternates between the sliding window (local) and causal (global)
|
||||
// kv cache every 6 layers
|
||||
if cache != nil {
|
||||
// cacheType := cacheTypeSWA
|
||||
// if (i+1)%gemmaGlobalCacheCount == 0 {
|
||||
// cacheType = cacheTypeCausal
|
||||
// }
|
||||
cache.SetLayer(i)
|
||||
|
||||
// TODO this needs to come back
|
||||
// wc := cache.(*kvcache.WrapperCache)
|
||||
// wc.SetLayerType(cacheType)
|
||||
|
||||
// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
// causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
// }
|
||||
}
|
||||
|
||||
var offset int
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
offset = batch.Offset
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
|
||||
}
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
}
|
||||
121
x/model/models/gemma3/model_vision.go
Normal file
121
x/model/models/gemma3/model_vision.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int = 1
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"self_attn.q_proj"`
|
||||
Key *nn.Linear `gguf:"self_attn.k_proj"`
|
||||
Value *nn.Linear `gguf:"self_attn.v_proj"`
|
||||
Output *nn.Linear `gguf:"self_attn.out_proj"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `gguf:"fc1"`
|
||||
FC2 *nn.Linear `gguf:"fc2"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
|
||||
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
|
||||
MLP *VisionMLP `gguf:"mlp"`
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// self attention
|
||||
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// feed forward
|
||||
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize, numHeads int
|
||||
imageSize, patchSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
|
||||
PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
|
||||
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"encoder.layers"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
|
||||
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
|
||||
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
|
||||
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon"),
|
||||
},
|
||||
}
|
||||
}
|
||||
60
x/model/models/gemma3/process_image.go
Normal file
60
x/model/models/gemma3/process_image.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"image"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize, patchSize, numChannels int
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
numChannels: int(c.Uint("vision.num_channels")),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
|
||||
var pixelVals, rVals, gVals, bVals []float32
|
||||
|
||||
bounds := img.Bounds()
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
c := img.At(x, y)
|
||||
r, g, b, _ := c.RGBA()
|
||||
rVal := float32(r>>8) / 255.0
|
||||
gVal := float32(g>>8) / 255.0
|
||||
bVal := float32(b>>8) / 255.0
|
||||
|
||||
rVal = (rVal - mean[0]) / std[0]
|
||||
gVal = (gVal - mean[1]) / std[1]
|
||||
bVal = (bVal - mean[2]) / std[2]
|
||||
|
||||
rVals = append(rVals, rVal)
|
||||
gVals = append(gVals, gVal)
|
||||
bVals = append(bVals, bVal)
|
||||
}
|
||||
}
|
||||
|
||||
pixelVals = append(pixelVals, rVals...)
|
||||
pixelVals = append(pixelVals, gVals...)
|
||||
pixelVals = append(pixelVals, bVals...)
|
||||
|
||||
return pixelVals
|
||||
}
|
||||
|
||||
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
|
||||
outputSize := image.Point{p.imageSize, p.imageSize}
|
||||
newImage := imageproc.Composite(img)
|
||||
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
|
||||
|
||||
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
|
||||
return data, nil
|
||||
}
|
||||
3
x/model/models/models.go
Normal file
3
x/model/models/models.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package models
|
||||
|
||||
// _ "github.com/ollama/ollama/x/model/models/gemma3"
|
||||
249
x/model/sentencepiece.go
Normal file
249
x/model/sentencepiece.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
type SentencePiece struct {
|
||||
maxTokenLen int
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
var maxTokenLen int
|
||||
for cnt := range vocab.Types {
|
||||
switch vocab.Types[cnt] {
|
||||
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
|
||||
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
|
||||
fallthrough
|
||||
default:
|
||||
counter[int(vocab.Types[cnt])] += 1
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
return SentencePiece{
|
||||
maxTokenLen: maxTokenLen,
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
id := spm.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
|
||||
|
||||
if id := spm.vocab.Encode(text); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
q := &queue{}
|
||||
heap.Init(q)
|
||||
|
||||
runes := []rune(text)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
size: len(left) + len(right),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for q.Len() > 0 {
|
||||
pair := heap.Pop(q).(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
|
||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if token := string(merge.runes); token != "" {
|
||||
id := spm.vocab.Encode(token)
|
||||
|
||||
if id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Fallback to byte tokenization
|
||||
var result []int32
|
||||
for _, b := range []byte(token) {
|
||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||
unknownID := spm.vocab.Encode(byteToken)
|
||||
if unknownID >= 0 {
|
||||
result = append(result, unknownID)
|
||||
} else {
|
||||
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
|
||||
}
|
||||
}
|
||||
|
||||
ids = append(ids, result...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = spm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
a, b int
|
||||
score float32
|
||||
size int
|
||||
}
|
||||
|
||||
type queue []*candidate
|
||||
|
||||
func (q queue) Len() int { return len(q) }
|
||||
|
||||
func (q queue) Less(i, j int) bool {
|
||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||
}
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*q = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||
}
|
||||
|
||||
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
172
x/model/sentencepiece_test.go
Normal file
172
x/model/sentencepiece_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
)
|
||||
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var spm sentencepiece.ModelProto
|
||||
if err := proto.Unmarshal(bts, &spm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var v Vocabulary
|
||||
|
||||
for _, piece := range spm.GetPieces() {
|
||||
v.Values = append(v.Values, piece.GetPiece())
|
||||
v.Scores = append(v.Scores, piece.GetScore())
|
||||
switch t := piece.GetType(); t {
|
||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
|
||||
sentencepiece.ModelProto_SentencePiece_CONTROL,
|
||||
sentencepiece.ModelProto_SentencePiece_UNUSED,
|
||||
sentencepiece.ModelProto_SentencePiece_BYTE:
|
||||
v.Types = append(v.Types, int32(t))
|
||||
default:
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
// todo parse the special tokens file
|
||||
// - this will roundtrip correctly but the <start_of_turn> and
|
||||
// <end_of_turn> tokens aren't processed
|
||||
v.Types = append(v.Types, tt)
|
||||
}
|
||||
}
|
||||
|
||||
return NewSentencePiece(&v)
|
||||
}
|
||||
|
||||
func TestSentencePieceEncode(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
tokenizer := loadSentencePieceVocab(t)
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
"你好",
|
||||
"Hello 你好 world!",
|
||||
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
|
||||
"Numbers and symbols: 123456789 +- */",
|
||||
"Special tokens: <bos> text <eos>",
|
||||
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
|
||||
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
|
||||
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
|
||||
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q [%#v]", got, want, ids)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special tokens", func(t *testing.T) {
|
||||
type candidate struct {
|
||||
token string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
cases := []candidate{
|
||||
{"<bos>", []int32{2}},
|
||||
{"<eos>", []int32{1}},
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want.token, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !slices.Equal(ids, want.ids) {
|
||||
t.Errorf("got %#v, want %#v", ids, want.ids)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"normal",
|
||||
"<0xEA>",
|
||||
"<0x41>",
|
||||
"<0xC3>",
|
||||
"<0xA3>",
|
||||
},
|
||||
Types: []int32{
|
||||
TOKEN_TYPE_NORMAL,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
},
|
||||
Scores: []float32{0, 0, 0, 0, 0},
|
||||
}
|
||||
|
||||
spm := NewSentencePiece(vocab)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single byte token",
|
||||
ids: []int32{1},
|
||||
expected: "\xea",
|
||||
},
|
||||
{
|
||||
name: "ASCII byte token",
|
||||
ids: []int32{2},
|
||||
expected: "A",
|
||||
},
|
||||
{
|
||||
name: "multiple byte tokens forming UTF-8 character",
|
||||
ids: []int32{3, 4},
|
||||
expected: "ã",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := spm.Decode(tt.ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user