mirror of
https://github.com/ollama/ollama.git
synced 2026-02-06 13:43:39 -05:00
Compare commits
6 Commits
v0.15.5
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5a2849d1f | ||
|
|
42e1d49fbe | ||
|
|
814630ca60 | ||
|
|
87cf187774 | ||
|
|
6ddd8862cd | ||
|
|
f1373193dc |
@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -58,14 +58,39 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -103,3 +104,95 @@ func TestClaudeArgs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelEnvVars(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -54,7 +53,6 @@ func migrateConfig() (bool, error) {
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -73,7 +71,6 @@ func migrateConfig() (bool, error) {
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
slog.Info("migrated config", "from", oldPath, "to", newPath)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -112,9 +114,17 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
}
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
@@ -122,7 +132,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
MaxOutputTokens: maxOutput,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
|
||||
@@ -1251,6 +1251,55 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
if err := d.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
|
||||
models := settings["customModels"].([]any)
|
||||
entry := models[0].(map[string]any)
|
||||
if entry["maxOutputTokens"] != float64(64000) {
|
||||
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// :cloud suffix stripping must also work since that's how users specify them.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -39,6 +39,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
|
||||
@@ -633,6 +633,7 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -24,6 +24,7 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/klauspost/compress v1.18.3
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
4
go.sum
4
go.sum
@@ -106,7 +106,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
|
||||
@@ -134,8 +133,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
|
||||
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -116,7 +117,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
var tok tokenizer.Tokenizer
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
tok, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
textProcessor != nil,
|
||||
tok != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
if tok != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func zstdCompress(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
w, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestResponsesMiddlewareZstd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
useZstd bool
|
||||
oversized bool
|
||||
wantCode int
|
||||
wantModel string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "plain JSON",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd compressed",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd over max decompressed size",
|
||||
oversized: true,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
var bodyReader io.Reader
|
||||
if tt.oversized {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
||||
} else if tt.useZstd {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
||||
} else {
|
||||
bodyReader = strings.NewReader(tt.body)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.useZstd || tt.oversized {
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.wantCode {
|
||||
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("expected captured request, got nil")
|
||||
}
|
||||
if capturedRequest.Model != tt.wantModel {
|
||||
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
||||
}
|
||||
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
||||
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &model.Vocabulary{
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -134,8 +135,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -43,8 +44,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Tokenizer: t,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,12 +29,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -32,8 +33,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Tokenizer: tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +34,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,28 +45,24 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWordPiece(t *testing.T) {
|
||||
wpm := NewWordPiece(
|
||||
&Vocabulary{
|
||||
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
BOS: []int32{1},
|
||||
EOS: []int32{2},
|
||||
},
|
||||
true, // lowercase
|
||||
)
|
||||
|
||||
ids, err := wpm.Encode("Hello world!", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
words, err := wpm.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
|
||||
decoder := func(tokenID int) string {
|
||||
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
|
||||
text, _ := tok.Decode([]int32{int32(tokenID)})
|
||||
return text
|
||||
}
|
||||
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
|
||||
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
|
||||
if err != nil {
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
|
||||
if seq.logprobs {
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
|
||||
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
|
||||
}
|
||||
|
||||
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -3,7 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -11,22 +11,13 @@ func Execute(args []string) error {
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var mlxRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--mlx-engine" {
|
||||
args = args[1:]
|
||||
mlxRunner = true
|
||||
}
|
||||
|
||||
if mlxRunner {
|
||||
return mlxrunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
if len(args) > 0 {
|
||||
switch args[0] {
|
||||
case "--ollama-engine":
|
||||
return ollamarunner.Execute(args[1:])
|
||||
case "--imagegen-engine":
|
||||
return imagegen.Execute(args[1:])
|
||||
}
|
||||
}
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -168,15 +168,15 @@ type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
|
||||
pieces := make([]string, len(tok.Vocabulary().Values))
|
||||
for i := range tok.Vocabulary().Values {
|
||||
pieces[i], _ = tok.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
return tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
Merges: merges,
|
||||
|
||||
@@ -52,7 +52,7 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
)
|
||||
|
||||
@@ -1106,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
modelDetails.QuantizationLevel = info.Quantization
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -567,16 +567,16 @@ iGPUScan:
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode mlxrunner.ModelMode
|
||||
var mode imagegen.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
mode = mlxrunner.ModeImageGen
|
||||
mode = imagegen.ModeImageGen
|
||||
} else {
|
||||
mode = mlxrunner.ModeLLM
|
||||
mode = imagegen.ModeLLM
|
||||
}
|
||||
|
||||
// Use model name for MLX (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := mlxrunner.NewServer(modelName, mode)
|
||||
server, err := imagegen.NewServer(modelName, mode)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||
if len(pretokenizer) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
for _, p := range pretokenizer {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
|
||||
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
@@ -17,7 +17,7 @@ type SentencePiece struct {
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
var _ Tokenizer = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// For tokenizer that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type TextProcessor interface {
|
||||
type Tokenizer interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements TextProcessor.
|
||||
// Decode implements Tokenizer.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements TextProcessor.
|
||||
// Encode implements Tokenizer.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements TextProcessor.
|
||||
// Is implements Tokenizer.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements TextProcessor.
|
||||
// Vocabulary implements Tokenizer.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
var _ Tokenizer = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
@@ -28,8 +28,8 @@ var imageGenMu sync.Mutex
|
||||
func (s *server) loadImageModel() error {
|
||||
// Check memory requirements before loading
|
||||
var requiredMemory uint64
|
||||
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
|
||||
requiredMemory = uint64(manifest.TotalTensorSize())
|
||||
if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
|
||||
requiredMemory = uint64(modelManifest.TotalTensorSize())
|
||||
}
|
||||
availableMemory := mlx.GetMemoryLimit()
|
||||
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
|
||||
@@ -38,7 +38,7 @@ func (s *server) loadImageModel() error {
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(s.modelName)
|
||||
modelType := DetectModelType(s.modelName)
|
||||
slog.Info("detected image model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
@@ -108,7 +108,7 @@ func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, r
|
||||
}
|
||||
|
||||
// Encode image as base64 PNG
|
||||
imageData, err := imagegen.EncodeImageBase64(img)
|
||||
imageData, err := EncodeImageBase64(img)
|
||||
if err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -197,13 +197,13 @@ func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
||||
func (s *server) loadLLMModel() error {
|
||||
// Load the manifest to get model information
|
||||
manifest, err := imagegen.LoadManifest(s.modelName)
|
||||
modelManifest, err := manifest.LoadManifest(s.modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Detect model architecture from config.json
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func (s *server) loadLLMModel() error {
|
||||
|
||||
switch {
|
||||
case strings.Contains(archLower, "glm4moelite"):
|
||||
m, err := glm4_moe_lite.LoadFromManifest(manifest)
|
||||
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
type ManifestWeights struct {
|
||||
manifest *ModelManifest
|
||||
component string
|
||||
tensors map[string]ManifestLayer // name -> layer
|
||||
cache map[string]*mlx.Array // name -> loaded array
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
tensors map[string]ManifestLayer // name -> layer
|
||||
cache map[string]*mlx.Array // name -> loaded array
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
}
|
||||
|
||||
// LoadWeightsFromManifest creates a weight loader from manifest storage.
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// SupportedBackends lists the backends that support image generation.
|
||||
@@ -41,8 +43,8 @@ func CheckPlatformSupport() error {
|
||||
// ResolveModelName checks if a model name is a known image generation model.
|
||||
// Returns the normalized model name if found, empty string otherwise.
|
||||
func ResolveModelName(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err == nil && manifest.HasTensorLayers() {
|
||||
modelManifest, err := manifest.LoadManifest(modelName)
|
||||
if err == nil && modelManifest.HasTensorLayers() {
|
||||
return modelName
|
||||
}
|
||||
return ""
|
||||
@@ -52,12 +54,12 @@ func ResolveModelName(modelName string) string {
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
func DetectModelType(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
modelManifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
data, err := modelManifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen3"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
manifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -14,19 +14,19 @@ import (
|
||||
|
||||
// TransformerConfig holds Flux2 transformer configuration
|
||||
type TransformerConfig struct {
|
||||
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
|
||||
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
|
||||
Eps float32 `json:"eps"` // 1e-6
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
|
||||
InChannels int32 `json:"in_channels"` // 128
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
|
||||
MLPRatio float32 `json:"mlp_ratio"` // 3.0
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
|
||||
NumLayers int32 `json:"num_layers"` // 5
|
||||
NumSingleLayers int32 `json:"num_single_layers"` // 20
|
||||
PatchSize int32 `json:"patch_size"` // 1
|
||||
RopeTheta int32 `json:"rope_theta"` // 2000
|
||||
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
|
||||
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
|
||||
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
|
||||
Eps float32 `json:"eps"` // 1e-6
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
|
||||
InChannels int32 `json:"in_channels"` // 128
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
|
||||
MLPRatio float32 `json:"mlp_ratio"` // 3.0
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
|
||||
NumLayers int32 `json:"num_layers"` // 5
|
||||
NumSingleLayers int32 `json:"num_single_layers"` // 20
|
||||
PatchSize int32 `json:"patch_size"` // 1
|
||||
RopeTheta int32 `json:"rope_theta"` // 2000
|
||||
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
|
||||
}
|
||||
|
||||
// Computed dimensions
|
||||
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
|
||||
}
|
||||
|
||||
// Load loads the Flux2 transformer from ollama blob storage.
|
||||
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = &cfg
|
||||
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
|
||||
}
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -15,21 +15,21 @@ import (
|
||||
|
||||
// VAEConfig holds AutoencoderKLFlux2 configuration
|
||||
type VAEConfig struct {
|
||||
ActFn string `json:"act_fn"` // "silu"
|
||||
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
|
||||
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
|
||||
ForceUpcast bool `json:"force_upcast"` // true
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 32
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 2
|
||||
ActFn string `json:"act_fn"` // "silu"
|
||||
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
|
||||
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
|
||||
ForceUpcast bool `json:"force_upcast"` // true
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 32
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 2
|
||||
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
PatchSize []int32 `json:"patch_size"` // [2, 2]
|
||||
SampleSize int32 `json:"sample_size"` // 1024
|
||||
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
|
||||
UseQuantConv bool `json:"use_quant_conv"` // true
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
PatchSize []int32 `json:"patch_size"` // [2, 2]
|
||||
SampleSize int32 `json:"sample_size"` // 1024
|
||||
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
|
||||
UseQuantConv bool `json:"use_quant_conv"` // true
|
||||
}
|
||||
|
||||
// BatchNorm2D implements 2D batch normalization with running statistics
|
||||
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
}
|
||||
|
||||
// Load loads the Flux2 VAE from ollama blob storage.
|
||||
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -38,11 +38,11 @@ type Config struct {
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
@@ -82,7 +82,7 @@ type MLAAttention struct {
|
||||
// Absorbed MLA projections (derived from kv_b_proj)
|
||||
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
|
||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
||||
UnembedOut *nn.MultiLinear `weight:"-"`
|
||||
|
||||
// Output projection
|
||||
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
|
||||
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
// Read config from manifest
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
@@ -634,7 +634,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
cfg.Scale = computeScale(&cfg)
|
||||
|
||||
// Load weights from manifest blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
@@ -653,7 +653,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
// Load tokenizer from manifest with config files for EOS token detection
|
||||
tokData, err := manifest.ReadConfig("tokenizer.json")
|
||||
tokData, err := modelManifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
@@ -664,12 +664,12 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
// Try to load generation_config.json if available (preferred source for EOS)
|
||||
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
|
||||
// Try to load tokenizer_config.json if available
|
||||
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -181,19 +181,19 @@ type TextEncoder struct {
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
|
||||
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg Config
|
||||
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
m.Layers = make([]*Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -38,7 +38,7 @@ type TransformerConfig struct {
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 nn.LinearLayer `weight:"mlp.0"`
|
||||
Linear2 nn.LinearLayer `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings -> [B, 256]
|
||||
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear nn.LinearLayer `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear nn.LinearLayer `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
|
||||
@@ -103,10 +103,9 @@ type FeedForward struct {
|
||||
W1 nn.LinearLayer `weight:"w1"` // gate projection
|
||||
W2 nn.LinearLayer `weight:"w2"` // down projection
|
||||
W3 nn.LinearLayer `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
OutDim int32 // computed from W2
|
||||
}
|
||||
|
||||
|
||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
@@ -132,11 +131,11 @@ type Attention struct {
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Fused QKV (computed at init time for efficiency, not loaded from weights)
|
||||
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
|
||||
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||
// Computed fields (not loaded from weights)
|
||||
NHeads int32 `weight:"-"`
|
||||
HeadDim int32 `weight:"-"`
|
||||
@@ -288,13 +287,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
|
||||
// TransformerBlock is a single transformer block with optional AdaLN modulation
|
||||
type TransformerBlock struct {
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
// Computed fields
|
||||
HasModulation bool
|
||||
Dim int32
|
||||
@@ -350,7 +349,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
|
||||
type FinalLayer struct {
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
|
||||
OutDim int32 // computed from Output
|
||||
OutDim int32 // computed from Output
|
||||
}
|
||||
|
||||
// Forward computes final output
|
||||
@@ -401,12 +400,12 @@ type Transformer struct {
|
||||
}
|
||||
|
||||
// Load loads the Z-Image transformer from ollama blob storage.
|
||||
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
@@ -417,7 +416,7 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
if ub.Upsample != nil {
|
||||
// Stage 1: Upsample2x (nearest neighbor)
|
||||
{
|
||||
prev := x
|
||||
prev := x
|
||||
x = Upsample2x(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 2: Upsample conv
|
||||
{
|
||||
prev := x
|
||||
prev := x
|
||||
x = ub.Upsample.Forward(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
@@ -643,16 +643,16 @@ type VAEDecoder struct {
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from ollama blob storage.
|
||||
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
@@ -18,14 +18,14 @@ import (
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// TeaCache options (timestep embedding aware caching)
|
||||
TeaCache bool // TeaCache is always enabled for faster inference
|
||||
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
manifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
|
||||
package mlxrunner
|
||||
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
@@ -98,7 +97,7 @@ func Execute(args []string) error {
|
||||
// detectModelMode determines whether a model is an LLM or image generation model.
|
||||
func detectModelMode(modelName string) ModelMode {
|
||||
// Check for image generation model by looking at model_index.json
|
||||
modelType := imagegen.DetectModelType(modelName)
|
||||
modelType := DetectModelType(modelName)
|
||||
if modelType != "" {
|
||||
// Known image generation model types
|
||||
switch modelType {
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !mlx
|
||||
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import "errors"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
|
||||
@@ -46,7 +46,7 @@ type Server struct {
|
||||
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
|
||||
func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
if err := imagegen.CheckPlatformSupport(); err != nil {
|
||||
if err := CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -71,8 +71,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
@@ -107,8 +107,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(manifest.TotalTensorSize())
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
// Fallback: default to 8GB if manifest can't be loaded
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
@@ -1,9 +1,9 @@
|
||||
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
|
||||
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
||||
//
|
||||
// This package handles safetensors models created with `ollama create --experimental`,
|
||||
// supporting both text generation (LLM) and image generation (diffusion) models
|
||||
// through a single unified interface.
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
// Request is the request format for completion requests.
|
||||
type Request struct {
|
||||
@@ -1,77 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||
ErrNotSupported = errors.New("model does not support operation")
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
// ** used by model implementations **
|
||||
|
||||
// SetLayer sets the active layer of the cache
|
||||
SetLayer(layer int)
|
||||
|
||||
// Get returns the history of key and value tensors plus a mask
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||
|
||||
// Put stores a batch of key and value in the cache
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output of the cache to work better with specific kernels. If not called,
|
||||
// the backend settings will be used. This works well when calling Attention.
|
||||
//
|
||||
// The config can be overridden by models, especially if they require vanilla
|
||||
// output when implementing their own version of attention. To do this, pass
|
||||
// an empty ml.CacheConfig.
|
||||
//
|
||||
// Most models will not need to use this.
|
||||
SetConfig(ml.CacheConfig)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters.
|
||||
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||
// dtype: The data type for storing cache entries
|
||||
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||
// capacity: The number of cache entries to store, per sequence
|
||||
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs. reserve is to preallocate memory
|
||||
// without actually storing data in the cache.
|
||||
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
// If an error occurs, the entire context for the sequence should be
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
// Causal cache stores K and V tensors according to their position in the
|
||||
// sequence. Returns the history and a mask for attending to past tokens
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocPut ml.Tensor
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocGet ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
capacity int
|
||||
|
||||
offset int
|
||||
|
||||
backend ml.Backend
|
||||
ctxs map[int]ml.Context
|
||||
keys, values map[int]ml.Tensor
|
||||
|
||||
// TODO is this needed per layer, or will it always be consistent?
|
||||
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
}
|
||||
|
||||
func NewCausalCache() *Causal {
|
||||
return &Causal{
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
kHeadDims: make(map[int]int),
|
||||
vHeadDims: make(map[int]int),
|
||||
numKVHeads: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.DType = dtype
|
||||
c.capacity = capacity
|
||||
c.backend = backend
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
for _, ctx := range c.ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
locsPut := make([]int32, len(batch.Positions))
|
||||
for i := c.offset; i < len(batch.Positions); i++ {
|
||||
locsPut[i-c.offset] = int32(i)
|
||||
}
|
||||
c.offset += len(batch.Positions)
|
||||
locsGet := make([]int32, c.offset)
|
||||
for i := range c.offset {
|
||||
locsGet[i] = int32(i)
|
||||
}
|
||||
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(3)
|
||||
vHeadDim := value.Dim(3)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
kCellSize := kHeadDim * numKVHeads
|
||||
vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||
|
||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||
c.kHeadDims[c.curLayer] = kHeadDim
|
||||
c.vHeadDims[c.curLayer] = vHeadDim
|
||||
c.numKVHeads[c.curLayer] = numKVHeads
|
||||
}
|
||||
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||
|
||||
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
|
||||
// slog.Info("XXX Causal.Put ", "key", key)
|
||||
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||
|
||||
}
|
||||
|
||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := c.kHeadDims[c.curLayer]
|
||||
vHeadDim := c.vHeadDims[c.curLayer]
|
||||
numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// rowSize := numKVHeads * c.curBatchSize
|
||||
// cachedSize := c.curMask.Dim(1)
|
||||
cachedSize := c.curLocGet.Dim(0)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||
|
||||
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Encoder cache stores K and V tensors that are position independent
|
||||
// //
|
||||
// // The tensors can be of any shape and will be returned as they were stored
|
||||
// // The mask is currently always nil
|
||||
// //
|
||||
// // Not currently safe for multiple sequences
|
||||
// type EncoderCache struct {
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // if something is stored during this pass, this
|
||||
// // will be the position (but there is no guarantee
|
||||
// // anything will be stored)
|
||||
// curPos int32
|
||||
|
||||
// // curReserve indicates that this forward pass is only for
|
||||
// // memory reservation and we should not update our metadata
|
||||
// // based on it.
|
||||
// curReserve bool
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // was something stored in the cache?
|
||||
// encoderCached bool
|
||||
|
||||
// // position of the cached data
|
||||
// encoderPos int32
|
||||
|
||||
// // ** cache data storage **
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
// }
|
||||
|
||||
// func NewEncoderCache() *EncoderCache {
|
||||
// return &EncoderCache{
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if maxSequences > 1 {
|
||||
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
// }
|
||||
|
||||
// c.backend = backend
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Close() {
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// // We work with the most recent image
|
||||
// if len(batch.Multimodal) > 0 {
|
||||
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
// }
|
||||
|
||||
// c.curReserve = reserve
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) EncoderCached() bool {
|
||||
// return c.encoderCached
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// if !c.curReserve {
|
||||
// c.encoderPos = c.curPos
|
||||
// c.encoderCached = true
|
||||
// }
|
||||
|
||||
// if c.config.PermutedV {
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3)
|
||||
// }
|
||||
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
||||
// }
|
||||
|
||||
// ctx.Forward(
|
||||
// key.Copy(ctx, c.keys[c.curLayer]),
|
||||
// value.Copy(ctx, c.values[c.curLayer]),
|
||||
// )
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// panic("encoder cache does not support multiple sequences")
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
// c.encoderCached = false
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
@@ -1,110 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "math"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Wrapper cache is a container for multiple types of caches,
|
||||
// // such as for the encoding and decoding portions of a model.
|
||||
// type WrapperCache struct {
|
||||
// // caches we are wrapping
|
||||
// caches []Cache
|
||||
|
||||
// // cache to be used for this layer
|
||||
// curType int
|
||||
// }
|
||||
|
||||
// func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||
// return &WrapperCache{
|
||||
// caches: caches,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.SetConfig(config)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Close() {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// for i, cache := range c.caches {
|
||||
// err := cache.StartForward(ctx, batch, reserve)
|
||||
// if err != nil {
|
||||
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
// for j := i - 1; j >= 0; j-- {
|
||||
// for k := range batch.Positions {
|
||||
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||
// }
|
||||
// }
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.curType = 0
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetLayer(layer int) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.SetLayer(layer)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetLayerType(layerType int) {
|
||||
// c.curType = layerType
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) UnderlyingCache() Cache {
|
||||
// return c.caches[c.curType]
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.caches[c.curType].Get(ctx)
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// c.caches[c.curType].Put(ctx, key, value)
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.CopyPrefix(srcSeq, dstSeq, len)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||
// for _, cache := range c.caches {
|
||||
// if !cache.CanResume(seq, pos) {
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
// for _, cache := range c.caches {
|
||||
// err := cache.Remove(seq, beginIndex, endIndex)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
433
x/ml/backend.go
433
x/ml/backend.go
@@ -1,433 +0,0 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type Backend interface {
|
||||
// Close frees all memory associated with this backend
|
||||
// Close()
|
||||
|
||||
// Load(ctx context.Context, progress func(float32)) error
|
||||
|
||||
// BackendMemory returns the memory allocations that were made for this model
|
||||
// BackendMemory() BackendMemory
|
||||
|
||||
Config() fs.Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
// NewContextSize(size int) Context
|
||||
|
||||
// Enumerate the devices available for inference via this backend
|
||||
// BackendDevices() []DeviceInfo
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
// from the cache to meet specific requirements. It is frequently implemented in
|
||||
// conjunction with ScaledDotProductAttention.
|
||||
type BackendCacheConfig interface {
|
||||
CacheConfig() CacheConfig
|
||||
}
|
||||
|
||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output the cache to work better with specific kernels.
|
||||
type CacheConfig struct {
|
||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||
// cache itself will also be increased to a multiple of this size if needed.
|
||||
CachePadding int
|
||||
|
||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||
// and return the permuted version via Get. This uses the cache copy operation
|
||||
// to avoid a Contiguous call on the permuted tensor.
|
||||
PermutedV bool
|
||||
|
||||
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||
// default to DTypeF32.
|
||||
MaskDType DType
|
||||
|
||||
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||
MaskBatchPadding int
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// AllocMemory causes the backend to allocate memory for the model. If
|
||||
// false, this is only being used for discovering the required amount of
|
||||
// memory and cannot load the model for running.
|
||||
AllocMemory bool
|
||||
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
NumThreads int
|
||||
|
||||
// GPULayers is the set of layers to offload to GPUs
|
||||
GPULayers GPULayersList
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
|
||||
if _, ok := backends[name]; ok {
|
||||
panic("backend: backend already registered")
|
||||
}
|
||||
|
||||
backends[name] = f
|
||||
}
|
||||
|
||||
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
|
||||
be := os.Getenv("OLLAMA_BACKEND")
|
||||
if be == "" {
|
||||
be = "mlx"
|
||||
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
|
||||
}
|
||||
slog.Info("Loading new engine", "backend", be)
|
||||
if backend, ok := backends[be]; ok {
|
||||
return backend(modelPath, params)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported backend")
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
|
||||
FromFloats(s []float32, shape ...int) Tensor
|
||||
FromInts(s []int32, shape ...int) Tensor
|
||||
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
|
||||
|
||||
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
|
||||
Arange(start, stop, step float32, dtype DType) Tensor
|
||||
|
||||
Forward(...Tensor) Context
|
||||
|
||||
// SetBatchSize provides a hint on the batch size to optimize processing
|
||||
// Uses heuristics if not set
|
||||
// SetBatchSize(int)
|
||||
|
||||
Compute(...Tensor)
|
||||
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
|
||||
|
||||
// Reserve is analogous to Compute but rather than executing a
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
// worst case graph to ensure all resources are available for
|
||||
// for future inference.
|
||||
// Reserve()
|
||||
|
||||
// MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
// Input returns a context appropriate for creating tensors that are
|
||||
// inputs to the model (which includes things like output locations)
|
||||
Input() Context
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
|
||||
// Load a tensor from "filename" safetensors file, and compare with the input tensor
|
||||
// Returns error if the shape is inconsistent, or similarity measures are below 99%
|
||||
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
|
||||
}
|
||||
|
||||
type RoPEOptions struct {
|
||||
Base *float32
|
||||
Freqs Tensor
|
||||
}
|
||||
|
||||
func WithRoPEBase(base float32) func(*RoPEOptions) {
|
||||
return func(opts *RoPEOptions) {
|
||||
opts.Base = &base
|
||||
}
|
||||
}
|
||||
|
||||
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
|
||||
return func(opts *RoPEOptions) {
|
||||
opts.Freqs = freqs
|
||||
}
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
ToString() string
|
||||
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
|
||||
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
|
||||
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
|
||||
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
|
||||
|
||||
Dim(n int) int
|
||||
Stride(n int) int
|
||||
|
||||
Shape() []int
|
||||
DType() DType
|
||||
// Cast(ctx Context, dtype DType) Tensor
|
||||
|
||||
// Bytes() []byte
|
||||
Floats() []float32
|
||||
Ints() []int32
|
||||
|
||||
// FromBytes([]byte)
|
||||
// FromFloats([]float32)
|
||||
// FromInts([]int32)
|
||||
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
// Mul(ctx Context, t2 Tensor) Tensor
|
||||
// Div(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
Max(ctx Context, axes []int, keepDims bool) Tensor
|
||||
Min(ctx Context, axes []int, keepDims bool) Tensor
|
||||
|
||||
Matmul(ctx Context, a2 Tensor) Tensor
|
||||
// Mulmat(ctx Context, t2 Tensor) Tensor
|
||||
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||
// MulmatID(ctx Context, t2, ids Tensor) Tensor
|
||||
// AddID(ctx Context, t2, ids Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
L2Norm(ctx Context, eps float32) Tensor
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
// SumRows(ctx Context) Tensor
|
||||
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
|
||||
|
||||
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
// Sin(ctx Context) Tensor
|
||||
// Cos(ctx Context) Tensor
|
||||
// Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
// QuickGELU(ctx Context, up ...Tensor) Tensor
|
||||
// SILU(ctx Context, up ...Tensor) Tensor
|
||||
// RELU(ctx Context, up ...Tensor) Tensor
|
||||
// Sigmoid(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
|
||||
Transpose(ctx Context, shape ...int) Tensor
|
||||
Contiguous(ctx Context, allowColMajor bool) Tensor
|
||||
|
||||
// Pad(ctx Context, shape ...int) Tensor
|
||||
|
||||
// Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
|
||||
// Repeat repeats the tensor n times along dimension dim
|
||||
// Repeat(ctx Context, dim, n int) Tensor
|
||||
// Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||
// Rows(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
// TODO these probably aren't actually needed - false starts on trying to wire up cache
|
||||
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
|
||||
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
|
||||
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
|
||||
|
||||
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
|
||||
|
||||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
// Duplicate(ctx Context) Tensor
|
||||
|
||||
// Slice(ctx Context, dim, low, high, step int) Tensor
|
||||
// Chunk(ctx Context, dim int, size int) []Tensor
|
||||
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
|
||||
|
||||
// TopK(ctx Context, k int) Tensor
|
||||
// Argsort(ctx Context) Tensor
|
||||
// Mean(ctx Context) Tensor
|
||||
// Variance(ctx Context) Tensor
|
||||
// Stddev(ctx Context) Tensor
|
||||
// Sqr(ctx Context) Tensor
|
||||
// Sqrt(ctx Context) Tensor
|
||||
|
||||
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
//
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
//
|
||||
// kq = kq.Softmax(ctx)
|
||||
//
|
||||
// kqv := value.Mulmat(ctx, kq)
|
||||
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
// type ScaledDotProductAttention interface {
|
||||
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
|
||||
// }
|
||||
|
||||
// type number interface {
|
||||
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
// ~float32 | ~float64 |
|
||||
// ~complex64 | ~complex128
|
||||
// }
|
||||
|
||||
// func mul[T number](s ...T) T {
|
||||
// p := T(1)
|
||||
// for _, v := range s {
|
||||
// p *= v
|
||||
// }
|
||||
|
||||
// return p
|
||||
// }
|
||||
|
||||
// type DumpOptions func(*dumpOptions)
|
||||
|
||||
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
|
||||
// func DumpWithPrecision(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.Precision = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
|
||||
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
|
||||
// // beginning and end of each dimension will be printed.
|
||||
// func DumpWithThreshold(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.Threshold = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
|
||||
// func DumpWithEdgeItems(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.EdgeItems = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// type dumpOptions struct {
|
||||
// Precision, Threshold, EdgeItems int
|
||||
// }
|
||||
|
||||
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
|
||||
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
|
||||
// for _, optsFunc := range optsFuncs {
|
||||
// optsFunc(&opts)
|
||||
// }
|
||||
|
||||
// if mul(t.Shape()...) <= opts.Threshold {
|
||||
// opts.EdgeItems = math.MaxInt
|
||||
// }
|
||||
|
||||
// switch t.DType() {
|
||||
// case DTypeFloat32:
|
||||
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
|
||||
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
// })
|
||||
// case DTypeFloat16: // TODO other types...
|
||||
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
|
||||
// f32 = t.Copy(ctx, f32)
|
||||
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
|
||||
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
// })
|
||||
// case DTypeInt32:
|
||||
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
|
||||
// return strconv.FormatInt(int64(i), 10)
|
||||
// })
|
||||
// default:
|
||||
// return "<unsupported>"
|
||||
// }
|
||||
// }
|
||||
|
||||
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
// if t.Bytes() == nil {
|
||||
// ctx.Compute(t)
|
||||
// }
|
||||
|
||||
// s := make(S, mul(t.Shape()...))
|
||||
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// shape := t.Shape()
|
||||
// slices.Reverse(shape)
|
||||
|
||||
// var sb strings.Builder
|
||||
// var f func([]int, int)
|
||||
// f = func(dims []int, stride int) {
|
||||
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||
// sb.WriteString("[")
|
||||
// defer func() { sb.WriteString("]") }()
|
||||
// for i := 0; i < dims[0]; i++ {
|
||||
// if i >= items && i < dims[0]-items {
|
||||
// sb.WriteString("..., ")
|
||||
// // skip to next printable element
|
||||
// skip := dims[0] - 2*items
|
||||
// if len(dims) > 1 {
|
||||
// stride += mul(append(dims[1:], skip)...)
|
||||
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
||||
// }
|
||||
// i += skip - 1
|
||||
// } else if len(dims) > 1 {
|
||||
// f(dims[1:], stride)
|
||||
// stride += mul(dims[1:]...)
|
||||
// if i < dims[0]-1 {
|
||||
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||
// }
|
||||
// } else {
|
||||
// text := fn(s[stride+i])
|
||||
// if len(text) > 0 && text[0] != '-' {
|
||||
// sb.WriteString(" ")
|
||||
// }
|
||||
|
||||
// sb.WriteString(text)
|
||||
// if i < dims[0]-1 {
|
||||
// sb.WriteString(", ")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// f(shape, 0)
|
||||
|
||||
// return sb.String()
|
||||
// }
|
||||
|
||||
type DType int
|
||||
|
||||
const (
|
||||
DTypeBool DType = iota
|
||||
DTypeUint8
|
||||
DTypeUint16
|
||||
DTypeUint32
|
||||
DTypeUint64
|
||||
DTypeInt8
|
||||
DTypeInt16
|
||||
DTypeInt32
|
||||
DTypeInt64
|
||||
DTypeFloat16
|
||||
DTypeFloat32
|
||||
DTypeFloat64
|
||||
DTypeBfloat16
|
||||
DTypeComplex64
|
||||
)
|
||||
|
||||
type SamplingMode int
|
||||
|
||||
const (
|
||||
SamplingModeNearest SamplingMode = iota
|
||||
SamplingModeBilinear
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
package backend
|
||||
|
||||
// _ "github.com/ollama/ollama/x/ml/backend/mlx"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,92 +0,0 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.dylib",
|
||||
"@loader_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"@executable_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"build/lib/ollama/libmlxc.dylib",
|
||||
"../build/lib/ollama/libmlxc.dylib",
|
||||
NULL
|
||||
};
|
||||
#else
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.so",
|
||||
"$ORIGIN/../build/lib/ollama/libmlxc.so",
|
||||
"build/lib/ollama/libmlxc.so",
|
||||
"../build/lib/ollama/libmlxc.so",
|
||||
NULL
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
// Try each possible library path
|
||||
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
|
||||
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
|
||||
if (mlx_handle != NULL) {
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", LIB_NAMES[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to load library
|
||||
const char* err = LIB_ERROR();
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. %s",
|
||||
err ? err : "Unknown error");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
@@ -1,314 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
_ "github.com/ollama/ollama/x/model/models/gemma3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
|
||||
func TestLoadModel(t *testing.T) {
|
||||
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||
b := &Backend{}
|
||||
err := b.LoadSafeTensors(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("load failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromInts(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
data := []int32{1, 2, 3, 4, 5, 6}
|
||||
a := c.FromInts(data, 2, 3)
|
||||
slog.Info("", "array", a)
|
||||
t.Log(a.ToString())
|
||||
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromFloats(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
data := []float32{1, 2, 3, 4, 5, 6}
|
||||
a := c.FromFloats(data, 2, 3)
|
||||
slog.Info("", "array", a)
|
||||
t.Log(a.ToString())
|
||||
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||
}
|
||||
res := a.Floats()
|
||||
if !reflect.DeepEqual(res, data) {
|
||||
t.Fatalf("incorrect results: %v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
|
||||
t3 := t1.Add(c, t2)
|
||||
c.Compute(t3, exp)
|
||||
t3f := t3.Floats()
|
||||
if !reflect.DeepEqual(t3f, exp.Floats()) {
|
||||
t.Fatalf("incorrect result: %v", t3f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReshapeTranspose(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
|
||||
c.Compute(t1)
|
||||
t1f := t1.Floats()
|
||||
exp := []float32{
|
||||
0, 4, 8,
|
||||
1, 5, 9,
|
||||
2, 6, 10,
|
||||
3, 7, 11,
|
||||
12, 16, 20,
|
||||
13, 17, 21,
|
||||
14, 18, 22,
|
||||
15, 19, 23,
|
||||
}
|
||||
if !reflect.DeepEqual(t1f, exp) {
|
||||
t.Fatalf("incorrect results: %v", t1f)
|
||||
}
|
||||
}
|
||||
|
||||
func prod(vals ...int) int {
|
||||
r := 1
|
||||
for _, v := range vals {
|
||||
r *= v
|
||||
}
|
||||
return r
|
||||
}
|
||||
func TestMatmul(t *testing.T) {
|
||||
// TODO create scenarios...
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
s1 := []int{1, 3, 2, 4}
|
||||
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
|
||||
s2 := []int{4, 2}
|
||||
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
|
||||
t3 := t1.Matmul(c, t2)
|
||||
exp := []float32{
|
||||
28, 34,
|
||||
76, 98,
|
||||
|
||||
124, 162,
|
||||
172, 226,
|
||||
|
||||
220, 290,
|
||||
268, 354,
|
||||
}
|
||||
c.Compute(t3)
|
||||
t3f := t3.Floats()
|
||||
if !reflect.DeepEqual(t3f, exp) {
|
||||
t.Fatalf("incorrect result: %v", t3f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRows(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
|
||||
outputs := c.Zeros(ml.DTypeInt32, 1)
|
||||
t2 := t1.TakeAxes(c, outputs, 1)
|
||||
c.Forward(t1, t2).Compute(t1, t2)
|
||||
t.Log(t1.ToString())
|
||||
t.Log(t2.ToString())
|
||||
f := t2.Floats()
|
||||
t.Logf("Result: %v", f)
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
// Validate the caching algorithm
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
batchSize := 3
|
||||
headDim := 4
|
||||
numKVHeads := 2
|
||||
// Make cache twice the size of one test batch
|
||||
cells := batchSize * 2
|
||||
cellSize := numKVHeads * headDim
|
||||
shape := []int{1, numKVHeads, batchSize, headDim}
|
||||
stop := float32(1)
|
||||
for _, x := range shape {
|
||||
stop *= float32(x)
|
||||
}
|
||||
// Create the cache
|
||||
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
|
||||
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
|
||||
|
||||
// Input tensor
|
||||
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
|
||||
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
|
||||
|
||||
// Reshape to copy into the cache
|
||||
/*
|
||||
From MLX python/src/indexing.cpp mlx_scatter_args_array
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
auto up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
up = reshape(up, up_shape);
|
||||
*/
|
||||
numRows := 3
|
||||
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
|
||||
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
|
||||
|
||||
// Simulate cells 1,3,5 are available
|
||||
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
|
||||
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
|
||||
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
|
||||
cache.Scatter(c, indicies, up, axis)
|
||||
|
||||
c.Forward(cache)
|
||||
// Cache should contain the data now
|
||||
t.Log("Cache after put\n" + cache.ToString())
|
||||
|
||||
// Retrieve cache content and verify it matches
|
||||
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
|
||||
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
|
||||
|
||||
t1f := t1.Floats()
|
||||
outf := out.Floats()
|
||||
if !reflect.DeepEqual(t1f, outf) {
|
||||
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma3(t *testing.T) {
|
||||
// Why is the sky blue
|
||||
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
|
||||
limit := 50
|
||||
|
||||
// TODO generalize this
|
||||
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||
|
||||
m, err := model.New(dir, ml.BackendParams{})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to load model: %s", err)
|
||||
}
|
||||
b := m.Backend()
|
||||
ctx := b.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
|
||||
Positions: make([]int32, len(inputs)),
|
||||
Sequences: make([]int, len(inputs)),
|
||||
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
|
||||
Offset: 0,
|
||||
}
|
||||
for i := range len(inputs) {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
offset := len(inputs)
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
numSlots := 1
|
||||
batchSize := 512
|
||||
numCtx := 4096
|
||||
|
||||
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
|
||||
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
|
||||
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
|
||||
|
||||
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed cache.StartForward: %s", err)
|
||||
}
|
||||
}
|
||||
opts := api.DefaultOptions()
|
||||
var grammar *sample.GrammarSampler
|
||||
sampler := sample.NewSampler(
|
||||
opts.Temperature,
|
||||
opts.TopK,
|
||||
opts.TopP,
|
||||
opts.MinP,
|
||||
opts.Seed,
|
||||
grammar,
|
||||
)
|
||||
|
||||
t.Log("Starting Forward pass loop")
|
||||
pendingResponses := []string{}
|
||||
for {
|
||||
out, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
t.Fatalf("failed forward pass: %s", err)
|
||||
}
|
||||
ctx.Forward(out)
|
||||
outputs := out.Floats()
|
||||
t.Logf("finished forward pass! length:%d", len(outputs))
|
||||
// sample a token
|
||||
logits := outputs
|
||||
token, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to sample token: %s", err)
|
||||
}
|
||||
t.Logf("Sampled token: %v", token)
|
||||
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
t.Log("hit EOS")
|
||||
break
|
||||
}
|
||||
piece, err := m.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode token: %s", err)
|
||||
}
|
||||
|
||||
pendingResponses = append(pendingResponses, piece)
|
||||
sequence := strings.Join(pendingResponses, "")
|
||||
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
|
||||
t.Logf("hit stop token: %v", stop)
|
||||
break
|
||||
}
|
||||
t.Logf("RESULTS: %s", sequence)
|
||||
batch = input.Batch{
|
||||
Inputs: ctx.FromInts([]int32{token}, 1, 1),
|
||||
Positions: make([]int32, 1),
|
||||
Sequences: make([]int, 1),
|
||||
Outputs: ctx.FromInts([]int32{0}, 1),
|
||||
Offset: offset,
|
||||
}
|
||||
offset++
|
||||
batch.Positions[0] = 0
|
||||
err = cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed cache.StartForward: %s", err)
|
||||
}
|
||||
if offset > limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,335 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/ops.h"
|
||||
|
||||
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
|
||||
|
||||
void unpack_32_4(uint8_t* data, int8_t* dst) {
|
||||
memset(dst, 0, 16);
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
|
||||
if (j % 2 != 0) {
|
||||
x <<= 4;
|
||||
}
|
||||
dst[j / 2] += x;
|
||||
}
|
||||
// Last 16 weights are in the higher bits
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uint8_t x = (data[j + 2] >> 4);
|
||||
if (j % 2 != 0) {
|
||||
x <<= 4;
|
||||
}
|
||||
dst[8 + j / 2] += x;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q4_0 tensors.
|
||||
// Data layout is: |16 bit scale|32 x 4bit weights|.
|
||||
void extract_q4_0_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
scales[i] = *((float16_t*)data);
|
||||
biases[i] = -8 * scales[i];
|
||||
unpack_32_4(data, weights);
|
||||
weights += 16;
|
||||
data += bytes_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q4_1 tensors.
|
||||
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
|
||||
void extract_q4_1_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
scales[i] = *((float16_t*)data);
|
||||
biases[i] = *((float16_t*)(data) + 1);
|
||||
unpack_32_4(data, weights);
|
||||
weights += 16;
|
||||
data += bytes_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q8_0 tensors.
|
||||
// Data layout is: |16 bit scale|32 x 8bit weights|.
|
||||
void extract_q8_0_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t weights_per_block = 32;
|
||||
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
uint8_t* block_data = data + i * bytes_per_block;
|
||||
scales[i] = *((float16_t*)block_data);
|
||||
biases[i] = -128 * scales[i];
|
||||
for (int64_t j = 0; j < weights_per_block; ++j) {
|
||||
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
|
||||
// Original data is in int8_t, so we add a bias of -128 and invert the
|
||||
// first bit.
|
||||
x ^= 1 << 7;
|
||||
weights[i * weights_per_block + j] = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Drived from ggml-quants.c
|
||||
|
||||
#define QK_K 256
|
||||
|
||||
// 6-bit quantization
|
||||
// weight is represented as x = a * q
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 6.5625 bits per weight
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||
uint16_t d; // super-block scale
|
||||
} block_q6_K;
|
||||
|
||||
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
|
||||
const int64_t nb = k / QK_K;
|
||||
block_q6_K *x = (block_q6_K *)vx;
|
||||
float16_t* y = (float16_t *)vy;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float16_t d = 0.0;
|
||||
memcpy(&d, &x[i].d, sizeof(d));
|
||||
|
||||
const uint8_t * restrict ql = x[i].ql;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const int8_t * restrict sc = x[i].scales;
|
||||
|
||||
for (int n = 0; n < QK_K; n += 128) {
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
int is = l/16;
|
||||
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
||||
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
||||
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
||||
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
||||
y[l + 0] = d * sc[is + 0] * q1;
|
||||
y[l + 32] = d * sc[is + 2] * q2;
|
||||
y[l + 64] = d * sc[is + 4] * q3;
|
||||
y[l + 96] = d * sc[is + 6] * q4;
|
||||
}
|
||||
y += 128;
|
||||
ql += 64;
|
||||
qh += 32;
|
||||
sc += 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define K_SCALE_SIZE 12
|
||||
#define GGML_COMMON_AGGR_U
|
||||
#define GGML_COMMON_AGGR_S
|
||||
|
||||
// 4-bit quantization
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 4.5 bits per weight
|
||||
typedef struct {
|
||||
union {
|
||||
struct {
|
||||
uint16_t d; // super-block scale for quantized scales
|
||||
uint16_t dmin; // super-block scale for quantized mins
|
||||
} GGML_COMMON_AGGR_S;
|
||||
uint16_t dm;
|
||||
} GGML_COMMON_AGGR_U;
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
|
||||
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63; *m = q[j + 4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||
}
|
||||
}
|
||||
|
||||
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
|
||||
block_q4_K *x = (block_q4_K *)vx;
|
||||
float16_t* y = (float16_t *)vy;
|
||||
const int nb = k / QK_K;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * q = x[i].qs;
|
||||
float16_t d = 0.0;
|
||||
memcpy(&d, &x[i].d, sizeof(d));
|
||||
float16_t min = 0.0;
|
||||
memcpy(&min, &x[i].dmin, sizeof(d));
|
||||
|
||||
int is = 0;
|
||||
uint8_t sc, m;
|
||||
for (int j = 0; j < QK_K; j += 64) {
|
||||
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
|
||||
const float16_t d1 = d * sc; const float16_t m1 = min * m;
|
||||
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
|
||||
const float16_t d2 = d * sc; const float16_t m2 = min * m;
|
||||
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
||||
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
||||
q += 32; is += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||
shape := append([]C.int{}, final_shape...)
|
||||
var weights_per_byte C.int
|
||||
if dtype == 2 || dtype == 3 {
|
||||
weights_per_byte = 2
|
||||
} else if dtype == 8 {
|
||||
weights_per_byte = 1
|
||||
} else {
|
||||
return r, fmt.Errorf("unsupported tensor type %d", dtype)
|
||||
}
|
||||
|
||||
weights_per_block := C.int(32)
|
||||
if shape[len(shape)-1]%weights_per_block != 0 {
|
||||
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
|
||||
}
|
||||
|
||||
weights_shape := append([]C.int{}, shape...)
|
||||
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
|
||||
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
|
||||
for i := range weights_shape {
|
||||
w_nbytes *= weights_shape[i]
|
||||
}
|
||||
w_data := make([]byte, w_nbytes)
|
||||
cbytes := C.CBytes(w_data)
|
||||
defer C.free(cbytes)
|
||||
weights := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&weights_shape[0],
|
||||
C.int(len(weights_shape)),
|
||||
C.MLX_UINT32,
|
||||
)
|
||||
|
||||
// For scales and bias
|
||||
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
|
||||
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
|
||||
for i := range shape {
|
||||
sb_nbytes *= shape[i]
|
||||
}
|
||||
|
||||
s_data := make([]byte, sb_nbytes)
|
||||
cbytes = C.CBytes(s_data)
|
||||
defer C.free(cbytes)
|
||||
scales := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
b_data := make([]byte, sb_nbytes)
|
||||
cbytes = C.CBytes(b_data)
|
||||
defer C.free(cbytes)
|
||||
biases := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
var bits C.int
|
||||
switch dtype {
|
||||
case 2:
|
||||
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 4
|
||||
case 3:
|
||||
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 4
|
||||
case 8:
|
||||
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 8
|
||||
}
|
||||
groupSize := C.mlx_optional_int{value: 32, has_value: true}
|
||||
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
|
||||
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
|
||||
C.mlx_dequantize(
|
||||
&r,
|
||||
weights,
|
||||
scales,
|
||||
biases,
|
||||
groupSize,
|
||||
bitsOpt,
|
||||
nil, // TODO mode
|
||||
dtypeOpt,
|
||||
stream,
|
||||
)
|
||||
C.mlx_array_free(weights)
|
||||
C.mlx_array_free(scales)
|
||||
C.mlx_array_free(biases)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||
size := 1
|
||||
for _, d := range shape {
|
||||
size *= int(d)
|
||||
}
|
||||
fdata := make([]float16.Float16, size)
|
||||
switch dtype {
|
||||
case 14:
|
||||
C.dequant_row_q6_K(
|
||||
data,
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
C.int(size),
|
||||
)
|
||||
|
||||
case 12:
|
||||
C.dequant_row_q4_K(
|
||||
data,
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
C.int(size),
|
||||
)
|
||||
default:
|
||||
return r, fmt.Errorf("unsupported K quant")
|
||||
}
|
||||
|
||||
r = C.mlx_array_new_data(
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
return r, nil
|
||||
}
|
||||
643
x/ml/device.go
643
x/ml/device.go
@@ -1,643 +0,0 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// GPULayers is a set of layers to be allocated on a single GPU
|
||||
type GPULayers struct {
|
||||
DeviceID
|
||||
|
||||
// Layers is a set of layer indicies to load
|
||||
Layers []int
|
||||
}
|
||||
|
||||
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
|
||||
func (g GPULayers) FirstLayer() int {
|
||||
if len(g.Layers) == 0 {
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
first := g.Layers[0]
|
||||
for i := 1; i < len(g.Layers); i++ {
|
||||
if g.Layers[i] < first {
|
||||
first = g.Layers[i]
|
||||
}
|
||||
}
|
||||
|
||||
return first
|
||||
}
|
||||
|
||||
func (g GPULayers) String() string {
|
||||
if len(g.Layers) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
slices.Sort(g.Layers)
|
||||
|
||||
contiguous := true
|
||||
base := g.Layers[0]
|
||||
for i := range g.Layers {
|
||||
if g.Layers[i] != base+i {
|
||||
contiguous = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if contiguous {
|
||||
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
|
||||
} else {
|
||||
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
|
||||
}
|
||||
}
|
||||
|
||||
// GPULayersList is a set of layer allocations across multiple GPUs
|
||||
type GPULayersList []GPULayers
|
||||
|
||||
func (l GPULayersList) Len() int { return len(l) }
|
||||
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
|
||||
|
||||
// Sort by the ordering of the layers offloaded
|
||||
func (l GPULayersList) Less(i, j int) bool {
|
||||
li := l[i].FirstLayer()
|
||||
lj := l[j].FirstLayer()
|
||||
|
||||
return li < lj
|
||||
}
|
||||
|
||||
func (l GPULayersList) String() string {
|
||||
if l.Sum() > 0 {
|
||||
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
|
||||
} else {
|
||||
return fmt.Sprintf("%v", []GPULayers(l))
|
||||
}
|
||||
}
|
||||
|
||||
// Sum is the total number of layers assigned across all GPUs
|
||||
func (l GPULayersList) Sum() int {
|
||||
var sum int
|
||||
|
||||
for _, g := range l {
|
||||
sum += len(g.Layers)
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
var h maphash.Hash
|
||||
|
||||
// Hash is an identifier of this layer assignment
|
||||
func (l GPULayersList) Hash() uint64 {
|
||||
h.Reset()
|
||||
for _, g := range l {
|
||||
if len(g.Layers) > 0 {
|
||||
h.WriteString(g.ID + g.Library)
|
||||
for _, l := range g.Layers {
|
||||
binary.Write(&h, binary.NativeEndian, int64(l))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// ErrNoMem is returned when panicing due to insufficient memory. It includes
|
||||
// the attempted memory allocation.
|
||||
type ErrNoMem struct {
|
||||
BackendMemory
|
||||
}
|
||||
|
||||
func (e ErrNoMem) Error() string {
|
||||
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
|
||||
}
|
||||
|
||||
// Minimal unique device identification
|
||||
type DeviceID struct {
|
||||
// ID is an identifier for the device for matching with system
|
||||
// management libraries. The ID is only unique for other devices
|
||||
// using the same Library.
|
||||
// This ID represents a "post filtered" view of the enumerated devices
|
||||
// if the ID is numeric
|
||||
ID string `json:"id"`
|
||||
|
||||
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
|
||||
Library string `json:"backend,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceMemory provides a breakdown of the memory needed
|
||||
// per device, such as a CPU or GPU.
|
||||
type DeviceMemory struct {
|
||||
DeviceID
|
||||
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string
|
||||
|
||||
// Weights is the per-layer memory needed for the model weights.
|
||||
Weights []uint64
|
||||
|
||||
// Cache is the per-layer memory needed for the KV cache.
|
||||
Cache []uint64
|
||||
|
||||
// Graph is the size of the compute graph. It is not per-layer.
|
||||
Graph uint64
|
||||
}
|
||||
|
||||
func sumMemory(mem []uint64) uint64 {
|
||||
var sum uint64
|
||||
|
||||
for _, m := range mem {
|
||||
sum += m
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
// Size returns the total size of the memory required by this device
|
||||
func (m DeviceMemory) Size() uint64 {
|
||||
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
|
||||
}
|
||||
|
||||
func memoryPresent(mem []uint64) bool {
|
||||
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
|
||||
}
|
||||
|
||||
func (m DeviceMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if memoryPresent(m.Weights) {
|
||||
attrs = append(attrs, slog.Any("Weights", m.Weights))
|
||||
}
|
||||
|
||||
if memoryPresent(m.Cache) {
|
||||
attrs = append(attrs, slog.Any("Cache", m.Cache))
|
||||
}
|
||||
|
||||
if m.Graph != 0 {
|
||||
attrs = append(attrs, slog.Any("Graph", m.Graph))
|
||||
}
|
||||
|
||||
if len(attrs) > 0 && m.ID != "" {
|
||||
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// BackendMemory provides the amount of memory required to load the model
|
||||
// per device based on the BackendParams. In some cases, not all required
|
||||
// allocations will be known at this point. However, the size of the most recent
|
||||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||
// accommodate that to make forward progress.
|
||||
type BackendMemory struct {
|
||||
// InputWeights are always located on the CPU and cannot be moved
|
||||
InputWeights uint64
|
||||
|
||||
// CPU model components are located in system memory. This does not
|
||||
// include unified memory allocated through the GPU.
|
||||
CPU DeviceMemory
|
||||
|
||||
// GPU model components are located on one or more GPUs.
|
||||
GPUs []DeviceMemory
|
||||
}
|
||||
|
||||
func (m BackendMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if m.InputWeights != 0 {
|
||||
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||
}
|
||||
|
||||
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
|
||||
for _, g := range m.GPUs {
|
||||
attrs = append(attrs, slog.Any(g.Name, g))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// Log prints a high level summary of the memory
|
||||
func (m BackendMemory) Log(level slog.Level) {
|
||||
var total uint64
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := sumMemory(m.CPU.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := gpu.Graph; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.CPU.Graph; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
|
||||
}
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
DeviceID
|
||||
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Description is the longer user-friendly identification of the device
|
||||
Description string `json:"description"`
|
||||
|
||||
// FilterID is populated with the unfiltered device ID if a numeric ID is used
|
||||
// so the device can be included.
|
||||
FilterID string `json:"filter_id,omitempty"`
|
||||
|
||||
// Integrated is set true for integrated GPUs, false for Discrete GPUs
|
||||
Integrated bool `json:"integration,omitempty"`
|
||||
|
||||
// PCIID is the bus, device and domain ID of the device for deduplication
|
||||
// when discovered by multiple backends
|
||||
PCIID string `json:"pci_id,omitempty"`
|
||||
|
||||
// TotalMemory is the total amount of memory the device can use for loading models
|
||||
TotalMemory uint64 `json:"total_memory"`
|
||||
|
||||
// FreeMemory is the amount of memory currently available on the device for loading models
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
|
||||
// ComputeMajor is the major version of capabilities of the device
|
||||
// if unsupported by the backend, -1 will be returned
|
||||
ComputeMajor int
|
||||
|
||||
// ComputeMinor is the minor version of capabilities of the device
|
||||
// if unsupported by the backend, -1 will be returned
|
||||
ComputeMinor int
|
||||
|
||||
// Driver Information
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
DriverMinor int `json:"driver_minor,omitempty"`
|
||||
|
||||
// Where backends were loaded from
|
||||
LibraryPath []string
|
||||
}
|
||||
|
||||
type SystemInfo struct {
|
||||
// ThreadCount is the optimal number of threads to use for inference
|
||||
ThreadCount int `json:"threads,omitempty"`
|
||||
|
||||
// TotalMemory is the total amount of system memory
|
||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||
|
||||
// FreeMemory is the amount of memory currently available on the system for loading models
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
|
||||
// FreeSwap is the amount of system swap space reported as available
|
||||
FreeSwap uint64 `json:"free_swap,omitempty"`
|
||||
}
|
||||
|
||||
func (d DeviceInfo) Compute() string {
|
||||
// AMD gfx is encoded into the major minor in hex form
|
||||
if strings.EqualFold(d.Library, "ROCm") {
|
||||
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
|
||||
}
|
||||
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
|
||||
}
|
||||
|
||||
func (d DeviceInfo) Driver() string {
|
||||
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
|
||||
}
|
||||
|
||||
// MinimumMemory reports the amount of memory that should be set aside
|
||||
// on the device for overhead (e.g. VRAM consumed by context structures independent
|
||||
// of model allocations)
|
||||
func (d DeviceInfo) MinimumMemory() uint64 {
|
||||
if d.Library == "Metal" {
|
||||
return 512 * format.MebiByte
|
||||
}
|
||||
return 457 * format.MebiByte
|
||||
}
|
||||
|
||||
// Sort by Free Space.
|
||||
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
|
||||
type ByFreeMemory []DeviceInfo
|
||||
|
||||
func (a ByFreeMemory) Len() int { return len(a) }
|
||||
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByFreeMemory) Less(i, j int) bool {
|
||||
if a[i].Integrated && !a[j].Integrated {
|
||||
return true
|
||||
} else if !a[i].Integrated && a[j].Integrated {
|
||||
return false
|
||||
}
|
||||
return a[i].FreeMemory < a[j].FreeMemory
|
||||
}
|
||||
|
||||
// ByPerformance groups devices by similar speed
|
||||
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
|
||||
resp := [][]DeviceInfo{}
|
||||
scores := []bool{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Integrated
|
||||
for i, score := range scores {
|
||||
if score == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
scores = append(scores, requested)
|
||||
resp = append(resp, []DeviceInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
|
||||
resp := [][]DeviceInfo{}
|
||||
libs := []string{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Library
|
||||
for i, lib := range libs {
|
||||
if lib == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
libs = append(libs, requested)
|
||||
resp = append(resp, []DeviceInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func LibraryPaths(l []DeviceInfo) []string {
|
||||
gpuLibs := []string{LibOllamaPath}
|
||||
for _, gpu := range l {
|
||||
for _, dir := range gpu.LibraryPath {
|
||||
needed := true
|
||||
for _, existing := range gpuLibs {
|
||||
if dir == existing {
|
||||
needed = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needed {
|
||||
gpuLibs = append(gpuLibs, dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
return gpuLibs
|
||||
}
|
||||
|
||||
type DeviceComparison int
|
||||
|
||||
const (
|
||||
UniqueDevice DeviceComparison = iota
|
||||
SameBackendDevice // The device is the same, and the library/backend is the same
|
||||
DuplicateDevice // The same physical device but different library/backend (overlapping device)
|
||||
)
|
||||
|
||||
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
|
||||
if a.PCIID != b.PCIID {
|
||||
return UniqueDevice
|
||||
}
|
||||
// If PCIID is empty, we have to use ID + library for uniqueness
|
||||
if a.PCIID == "" && a.DeviceID != b.DeviceID {
|
||||
return UniqueDevice
|
||||
}
|
||||
if a.Library == b.Library {
|
||||
return SameBackendDevice
|
||||
}
|
||||
return DuplicateDevice
|
||||
}
|
||||
|
||||
// For a SameBackendDevice, return true if b is better than a
|
||||
// e.g. newer GPU library version
|
||||
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
|
||||
aLib := a.LibraryPath[len(a.LibraryPath)-1]
|
||||
bLib := b.LibraryPath[len(b.LibraryPath)-1]
|
||||
if aLib == bLib {
|
||||
return false
|
||||
}
|
||||
aLibSplit := strings.SplitN(aLib, "_", 2)
|
||||
bLibSplit := strings.SplitN(bLib, "_", 2)
|
||||
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
|
||||
return false
|
||||
}
|
||||
if aLibSplit[0] != bLibSplit[0] {
|
||||
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
|
||||
return false
|
||||
}
|
||||
if aLibSplit[1] == bLibSplit[1] {
|
||||
return false
|
||||
}
|
||||
cmp := []string{aLibSplit[1], bLibSplit[1]}
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
|
||||
return cmp[0] == bLibSplit[1]
|
||||
}
|
||||
|
||||
// For each GPU, check if it does NOT support flash attention
|
||||
func FlashAttentionSupported(l []DeviceInfo) bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
|
||||
gpu.Library == "ROCm" ||
|
||||
gpu.Library == "Vulkan"
|
||||
|
||||
if !supportsFA {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variables
|
||||
// Set mustFilter true to enable filtering of CUDA devices
|
||||
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
env := map[string]string{}
|
||||
for _, d := range l {
|
||||
d.updateVisibleDevicesEnv(env, mustFilter)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// NeedsInitValidation returns true if the device in question has the potential
|
||||
// to crash at inference time and requires deeper validation before we include
|
||||
// it in the supported devices list.
|
||||
func (d DeviceInfo) NeedsInitValidation() bool {
|
||||
// ROCm: rocblas will crash on unsupported devices.
|
||||
// CUDA: verify CC is supported by the version of the library
|
||||
return d.Library == "ROCm" || d.Library == "CUDA"
|
||||
}
|
||||
|
||||
// Set the init validation environment variable
|
||||
func (d DeviceInfo) AddInitValidation(env map[string]string) {
|
||||
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
|
||||
}
|
||||
|
||||
// PreferredLibrary returns true if this library is preferred over the other input
|
||||
// library
|
||||
// Used to filter out Vulkan in favor of CUDA or ROCm
|
||||
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
|
||||
// TODO in the future if we find Vulkan is better than ROCm on some devices
|
||||
// that implementation can live here.
|
||||
|
||||
if d.Library == "CUDA" || d.Library == "ROCm" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
|
||||
var envVar string
|
||||
switch d.Library {
|
||||
case "ROCm":
|
||||
// ROCm must be filtered as it can crash the runner on unsupported devices
|
||||
envVar = "ROCR_VISIBLE_DEVICES"
|
||||
if runtime.GOOS != "linux" {
|
||||
envVar = "HIP_VISIBLE_DEVICES"
|
||||
}
|
||||
case "CUDA":
|
||||
if !mustFilter {
|
||||
// By default we try to avoid filtering CUDA devices because ROCm also
|
||||
// looks at the CUDA env var, and gets confused in mixed vendor environments.
|
||||
return
|
||||
}
|
||||
envVar = "CUDA_VISIBLE_DEVICES"
|
||||
default:
|
||||
// Vulkan is not filtered via env var, but via scheduling decisions
|
||||
return
|
||||
}
|
||||
v, existing := env[envVar]
|
||||
if existing {
|
||||
v = v + ","
|
||||
}
|
||||
if d.FilterID != "" {
|
||||
v = v + d.FilterID
|
||||
} else {
|
||||
v = v + d.ID
|
||||
}
|
||||
env[envVar] = v
|
||||
}
|
||||
|
||||
type BaseRunner interface {
|
||||
// GetPort returns the localhost port number the runner is running on
|
||||
GetPort() int
|
||||
|
||||
// HasExited indicates if the runner is no longer running. This can be used during
|
||||
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
|
||||
HasExited() bool
|
||||
}
|
||||
|
||||
type RunnerDiscovery interface {
|
||||
BaseRunner
|
||||
|
||||
// GetDeviceInfos will perform a query of the underlying device libraries
|
||||
// for device identification and free VRAM information
|
||||
// During bootstrap scenarios, this routine may take seconds to complete
|
||||
GetDeviceInfos(ctx context.Context) []DeviceInfo
|
||||
}
|
||||
|
||||
type FilteredRunnerDiscovery interface {
|
||||
RunnerDiscovery
|
||||
|
||||
// GetActiveDeviceIDs returns the filtered set of devices actively in
|
||||
// use by this runner for running models. If the runner is a bootstrap runner, no devices
|
||||
// will be active yet so no device IDs are returned.
|
||||
// This routine will not query the underlying device and will return immediately
|
||||
GetActiveDeviceIDs() []DeviceID
|
||||
}
|
||||
|
||||
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
|
||||
var moreDevices []DeviceInfo
|
||||
port := runner.GetPort()
|
||||
tick := time.Tick(10 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to finish discovery before timeout")
|
||||
case <-tick:
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
// slog.Warn("failed to send request", "error", err)
|
||||
if runner.HasExited() {
|
||||
return nil, fmt.Errorf("runner crashed")
|
||||
}
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// old runner, fall back to bootstrapping model
|
||||
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
slog.Warn("failed to read response", "error", err)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
|
||||
return nil, fmt.Errorf("runner error: %s", string(body))
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &moreDevices); err != nil {
|
||||
slog.Warn("unmarshal encode response", "error", err)
|
||||
continue
|
||||
}
|
||||
return moreDevices, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
// Attention implements scaled dot-product attention for transformer models:
|
||||
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
ctx.Forward(query)
|
||||
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
ctx.Forward(key, value)
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
} else if cache == nil {
|
||||
panic("key & value tensors must be provided if cache is nil")
|
||||
}
|
||||
|
||||
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
|
||||
// panic("after cache get") //
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
|
||||
|
||||
// var mask ml.Tensor
|
||||
if cache != nil {
|
||||
key, value, _ = cache.Get(ctx)
|
||||
}
|
||||
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
|
||||
// panic("after cache get") //
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
|
||||
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
|
||||
if cache != nil {
|
||||
// TODO what to do with vmla?
|
||||
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
|
||||
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
|
||||
|
||||
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
|
||||
} else {
|
||||
panic("else case not supported")
|
||||
// TODO transpose shapes are wrong
|
||||
// key = key.Transpose(ctx, 0, 2, 1, 3)
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
|
||||
|
||||
// kq := query.Matmul(ctx, key)
|
||||
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
// kq = kq.Softmax(ctx)
|
||||
|
||||
// kqv := kq.Matmul(ctx, value)
|
||||
|
||||
// if vmla != nil {
|
||||
// kqv = kqv.Matmul(ctx, vmla)
|
||||
// }
|
||||
|
||||
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Conv2D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
|
||||
if m.Bias != nil {
|
||||
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
|
||||
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type Conv3D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
|
||||
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Embedding struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
return m.Weight.TakeAxes(ctx, hiddenState, 0)
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Linear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type LinearBatch struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
|
||||
panic("not yet ported")
|
||||
// t = m.Weight.MulmatID(ctx, t, indices)
|
||||
// if m.Bias != nil {
|
||||
// t = t.AddID(ctx, m.Bias, indices)
|
||||
// }
|
||||
|
||||
// return t
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
|
||||
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
|
||||
// slog.Info("RMSNorm", "eps", eps)
|
||||
// fmt.Fprintln(os.Stderr, t.ToString())
|
||||
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
|
||||
|
||||
// TODO this is probably model specific, not generalized...
|
||||
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
|
||||
|
||||
return t.RMSNorm(ctx, w, eps)
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package pooling
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeNone Type = iota
|
||||
TypeMean
|
||||
TypeCLS
|
||||
TypeLast
|
||||
)
|
||||
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
return "Mean"
|
||||
case TypeCLS:
|
||||
return "CLS"
|
||||
case TypeLast:
|
||||
return "Last"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
switch t {
|
||||
// case TypeMean:
|
||||
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
|
||||
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
// case TypeCLS:
|
||||
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
|
||||
// case TypeLast:
|
||||
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
|
||||
default:
|
||||
panic("unknown pooling type")
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package rope
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
// Options contains optional parameters for RoPE function
|
||||
type Options struct {
|
||||
Type int
|
||||
Factors ml.Tensor
|
||||
|
||||
// YaRN options
|
||||
YaRN struct {
|
||||
OriginalContextLength int
|
||||
ExtrapolationFactor,
|
||||
AttentionFactor,
|
||||
BetaFast,
|
||||
BetaSlow float32
|
||||
}
|
||||
|
||||
// MRoPE options
|
||||
MRoPE struct {
|
||||
Sections []int
|
||||
}
|
||||
}
|
||||
|
||||
// WithTypeNeoX sets RoPE type to NeoX
|
||||
func WithTypeNeoX() func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type = 2
|
||||
}
|
||||
}
|
||||
|
||||
// WithFactors sets custom rope factors
|
||||
func WithFactors(factors ml.Tensor) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
if factors != nil {
|
||||
opts.Factors = factors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithOriginalContextLength sets a custom context length
|
||||
func WithOriginalContextLength(n int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.OriginalContextLength = n
|
||||
}
|
||||
}
|
||||
|
||||
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.ExtrapolationFactor = extrapolationFactor
|
||||
}
|
||||
}
|
||||
|
||||
func WithAttentionFactor(attentionFactor float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.AttentionFactor = attentionFactor
|
||||
}
|
||||
}
|
||||
|
||||
func WithMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1 << 3
|
||||
opts.MRoPE.Sections = sections
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1<<3 | 1<<5
|
||||
opts.MRoPE.Sections = sections
|
||||
}
|
||||
}
|
||||
56
x/ml/path.go
56
x/ml/path.go
@@ -1,56 +0,0 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// LibPath is a path to lookup dynamic libraries
|
||||
// in development it's usually 'build/lib/ollama'
|
||||
// in distribution builds it's 'lib/ollama' on Windows
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
var libPath string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
|
||||
case "linux":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
|
||||
case "darwin":
|
||||
libPath = filepath.Dir(exe)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
paths := []string{
|
||||
libPath,
|
||||
|
||||
// build paths for development
|
||||
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
|
||||
filepath.Join(cwd, "build", "lib", "ollama"),
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Dir(exe)
|
||||
}()
|
||||
@@ -1,322 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
types := make([]int32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = 1
|
||||
}
|
||||
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
tokens = append(tokens, token) //nolint:makezero
|
||||
types = append(types, 3) //nolint:makezero
|
||||
vocab[token] = int32(len(vocab))
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
merges := make([]string, 0, 50000)
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "#") {
|
||||
merges = append(merges, scanner.Text())
|
||||
}
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
}
|
||||
|
||||
func TestLlama(t *testing.T) {
|
||||
tokenizer := llama(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids, err := tokenizer.Encode("hello world", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
|
||||
s, err := tokenizer.Decode([]int32{15339, 1917})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s != "hello world" {
|
||||
t.Errorf("got %q, want hello world", s)
|
||||
}
|
||||
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple repeated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
strings.Repeat("0", 1): {15},
|
||||
strings.Repeat("0", 2): {410},
|
||||
strings.Repeat("0", 3): {931},
|
||||
strings.Repeat("0", 4): {931, 15},
|
||||
strings.Repeat("0", 5): {931, 410},
|
||||
strings.Repeat("0", 6): {931, 931},
|
||||
strings.Repeat("0", 7): {931, 931, 15},
|
||||
strings.Repeat("0", 8): {931, 931, 410},
|
||||
strings.Repeat("0", 9): {931, 931, 931},
|
||||
strings.Repeat("0", 10): {931, 931, 931, 15},
|
||||
strings.Repeat("0", 11): {931, 931, 931, 410},
|
||||
strings.Repeat("0", 12): {931, 931, 931, 931},
|
||||
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
|
||||
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
|
||||
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
|
||||
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
|
||||
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
|
||||
"Hello World": {"Hello", " ", " World"},
|
||||
"Hello\nWorld": {"Hello", "\n", "World"},
|
||||
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for b := 0x00; b <= 0xFF; b++ {
|
||||
input := string(rune(b))
|
||||
ids, err := tokenizer.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if b == 0x00 {
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != input {
|
||||
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
n := min(int(math.Pow10(i)), len(bts))
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
ids, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
patterns,
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
patterns: []string{
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||
},
|
||||
{
|
||||
name: "individual digits",
|
||||
patterns: []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
package input
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
// Multimodal is a multimodal embedding or a component of one.
|
||||
// For example, it could be a row of an image that can be processed
|
||||
// independently.
|
||||
type Multimodal struct {
|
||||
// Tensor is the embedding data. Implementations may chose what to
|
||||
// store here or it may be nil if not needed. However, any ml.Tensor
|
||||
// objects must be stored here and not in Data.
|
||||
Tensor ml.Tensor
|
||||
|
||||
// Data is implementation-specific opaque data, such as metadata on how
|
||||
// to layout Tensor. It may be nil if not needed. It may also store larger
|
||||
// objects such as complete images if they are to be processed later.
|
||||
Data any
|
||||
}
|
||||
|
||||
// Input represents one token in the input stream
|
||||
type Input struct {
|
||||
// Token is a single element of text.
|
||||
Token int32
|
||||
|
||||
// Multimodal is represents a non-text element such as an
|
||||
// image (or part of one if the image can be processed in pieces).
|
||||
// It may be used either together with Token or on its own.
|
||||
Multimodal []Multimodal
|
||||
|
||||
// MultimodalHash is a unique representation of the data
|
||||
// stored in Multimodal, used for caching and comparing
|
||||
// equality.
|
||||
MultimodalHash uint64
|
||||
|
||||
// SameBatch forces the following number of tokens to be processed
|
||||
// in a single batch, breaking and extending batches as needed.
|
||||
// Useful for things like images that must be processed in one
|
||||
// shot.
|
||||
SameBatch int
|
||||
}
|
||||
|
||||
// MultimodalIndex is a multimodal element (such as an image)
|
||||
// together with an index into the slice of Inputs with the
|
||||
// corresponding token. Note that the index is not the same
|
||||
// as the position - to find that use the index with the
|
||||
// Positions slice.
|
||||
type MultimodalIndex struct {
|
||||
Index int
|
||||
Multimodal []Multimodal
|
||||
}
|
||||
|
||||
// Batch contains the inputs for a model forward pass
|
||||
type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs ml.Tensor
|
||||
|
||||
// TODO maybe not the optimal way to handle this
|
||||
// Offset of final tensor in the final batch
|
||||
Offset int
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
Positions []int32
|
||||
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user