Compare commits

..

10 Commits

Author SHA1 Message Date
jmorganca
03f56d3c30 fix: allow app to build without UI dist directory
Add a placeholder README.md in app/ui/app/dist so the go:embed directive
doesn't fail when the UI hasn't been built. Also improve the error message
when index.html is not found to help developers understand they need to
run 'npm run build'.
2026-02-07 15:15:53 -08:00
Jeffrey Morgan
099a0f18ef build: fix Dockerfile mlx directory (#14131) 2026-02-06 17:08:53 -08:00
Richard Lyons
fff696ee31 docs: increased RAM requirement for parallelism 2026-02-06 15:49:39 -08:00
Jeffrey Morgan
2e3ce6eab3 anthropic: do not count image tokens for now (#14127) 2026-02-06 15:33:18 -08:00
Parth Sareen
9e2003f88a cmd/config: offer to pull missing models instead of erroring (#14113) 2026-02-06 10:19:47 -08:00
Parth Sareen
42e1d49fbe cmd: fix context limits for droid and add qwen3-coder-next ctx (#14112) 2026-02-05 22:29:53 -08:00
Michael Yang
814630ca60 Revert "move tokenizers to separate package (#13825)" (#14111) 2026-02-05 20:49:08 -08:00
Parth Sareen
87cf187774 cmd: set claude code env vars on launch (#14109)
Set ANTHROPIC_DEFAULT_OPUS_MODEL, ANTHROPIC_DEFAULT_SONNET_MODEL,
ANTHROPIC_DEFAULT_HAIKU_MODEL, and CLAUDE_CODE_SUBAGENT_MODEL when
launching Claude Code so all model tiers route through Ollama.
2026-02-05 19:04:53 -08:00
Michael Yang
6ddd8862cd chore: move x/mlxrunner into x/imagegen (#14100) 2026-02-05 18:25:56 -08:00
Michael Yang
f1373193dc move tokenizers to separate package (#13825) 2026-02-05 17:44:11 -08:00
115 changed files with 707 additions and 6834 deletions

View File

@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -147,7 +147,7 @@ ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local

View File

@@ -897,11 +897,5 @@ func countContentBlock(block any) int {
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

View File

@@ -33,7 +33,10 @@ func (s *Server) appHandler() http.Handler {
data, err := fs.ReadFile(fsys, "index.html")
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
http.NotFound(w, r)
// Development mode: UI not built
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("UI not built. Run 'npm run build' in app/ui/app directory."))
} else {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}

1
app/ui/app/dist/README.md vendored Normal file
View File

@@ -0,0 +1 @@
This directory contains the built React app. Run `npm run build` in the app directory to generate the build.

View File

@@ -58,14 +58,39 @@ func (c *Claude) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
primary := model
fast := model
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
if p := cfg.Aliases["primary"]; p != "" {
primary = p
}
if f := cfg.Aliases["fast"]; f != "" {
fast = f
}
}
return []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
}
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
@@ -103,3 +104,95 @@ func TestClaudeArgs(t *testing.T) {
})
}
}
func TestClaudeModelEnvVars(t *testing.T) {
c := &Claude{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses fast alias for haiku", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
})
got := envMap(c.modelEnvVars("llama3.2:70b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("alias primary overrides model param", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"saved-model"})
saveAliases("claude", map[string]string{"primary": "saved-model"})
got := envMap(c.modelEnvVars("different-model"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
})
}

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
@@ -54,7 +53,6 @@ func migrateConfig() (bool, error) {
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
return false, nil
}
@@ -73,7 +71,6 @@ func migrateConfig() (bool, error) {
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
slog.Info("migrated config", "from", oldPath, "to", newPath)
return true, nil
}

View File

@@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/json"
"fmt"
"os"
@@ -8,6 +9,7 @@ import (
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
@@ -112,9 +114,17 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}
}
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
@@ -122,7 +132,7 @@ func (d *Droid) Edit(models []string) error {
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,
MaxOutputTokens: maxOutput,
SupportsImages: false,
ID: modelID,
Index: i,

View File

@@ -1251,6 +1251,55 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
}
}
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
if err := d.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
models := settings["customModels"].([]any)
entry := models[0].(map[string]any)
if entry["maxOutputTokens"] != float64(64000) {
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
}
}
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// :cloud suffix stripping must also work since that's how users specify them.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
}
})
}
}
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()

View File

@@ -194,6 +194,20 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
return nil
}
// showOrPull checks if a model exists via client.Show and offers to pull it if not found.
func showOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullModel(ctx, client, model)
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -397,8 +411,11 @@ Examples:
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
@@ -424,9 +441,11 @@ Examples:
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
if err := showOrPull(cmd.Context(), client, model); err != nil {
model = ""
}
}
}
@@ -443,6 +462,13 @@ Examples:
existingAliases = aliases
}
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
return err
}
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
@@ -467,8 +493,11 @@ Examples:
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
@@ -650,7 +679,7 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return false
}

View File

@@ -2,12 +2,17 @@ package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
@@ -539,3 +544,136 @@ func TestAliasConfigurerInterface(t *testing.T) {
}
})
}
func TestShowOrPull_ModelExists(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"test-model"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := showOrPull(context.Background(), client, "test-model")
if err != nil {
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
}
}
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
err := showOrPull(context.Background(), client, "missing-model")
if err == nil {
t.Error("showOrPull should return error when model not found and no terminal available")
}
}
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
var receivedModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
receivedModel = req.Model
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
_ = showOrPull(context.Background(), client, "qwen3:8b")
if receivedModel != "qwen3:8b" {
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
}
}
func TestEnsureAuth_NoCloudModels(t *testing.T) {
// ensureAuth should be a no-op when no cloud models are selected
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("no API calls expected when no cloud models selected")
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
if err != nil {
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
}
}
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
// ensureAuth should only care about models in cloudModels map
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"name":"testuser"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"cloud-model:cloud", "local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
}
if !whoamiCalled {
t.Error("expected whoami to be called for cloud model")
}
}
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// cloudModels has entries but none are in selected
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if whoamiCalled {
t.Error("whoami should not be called when no cloud models are selected")
}
}

View File

@@ -39,6 +39,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}

View File

@@ -633,6 +633,7 @@ func TestLookupCloudModelLimit(t *testing.T) {
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}

View File

@@ -312,7 +312,7 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -116,7 +117,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tok tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tok, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tok == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tok == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tok != nil,
modelPath,
gpuLibs,
status,
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}

View File

@@ -1,272 +0,0 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,28 +45,24 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,53 +0,0 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tok.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/mlxrunner"
"github.com/ollama/ollama/x/imagegen"
)
func Execute(args []string) error {
@@ -11,22 +11,13 @@ func Execute(args []string) error {
args = args[1:]
}
var newRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
mlxRunner = true
}
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
if len(args) > 0 {
switch args[0] {
case "--ollama-engine":
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
}
}
return llamarunner.Execute(args)
}

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
&model.Vocabulary{
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -52,7 +52,7 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
xserver "github.com/ollama/ollama/x/server"
)
@@ -1106,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -567,16 +567,16 @@ iGPUScan:
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
mode = imagegen.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
mode = imagegen.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := mlxrunner.NewServer(modelName, mode)
server, err := imagegen.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"cmp"
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
var _ Tokenizer = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
for _, p := range pretokenizer {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
var _ Tokenizer = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// For tokenizer that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"testing"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements TextProcessor.
// Decode implements Tokenizer.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements TextProcessor.
// Encode implements Tokenizer.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements TextProcessor.
// Is implements Tokenizer.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
// Vocabulary implements Tokenizer.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"slices"

View File

@@ -1,6 +1,6 @@
//go:build mlx
package mlxrunner
package imagegen
import (
"context"
@@ -11,7 +11,7 @@ import (
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
@@ -28,8 +28,8 @@ var imageGenMu sync.Mutex
func (s *server) loadImageModel() error {
// Check memory requirements before loading
var requiredMemory uint64
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(manifest.TotalTensorSize())
if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(modelManifest.TotalTensorSize())
}
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
@@ -38,7 +38,7 @@ func (s *server) loadImageModel() error {
}
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(s.modelName)
modelType := DetectModelType(s.modelName)
slog.Info("detected image model type", "type", modelType)
var model ImageModel
@@ -108,7 +108,7 @@ func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, r
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
imageData, err := EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)

View File

@@ -1,6 +1,6 @@
//go:build mlx
package mlxrunner
package imagegen
import (
"encoding/json"
@@ -12,8 +12,8 @@ import (
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -197,13 +197,13 @@ func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
// Load the manifest to get model information
manifest, err := imagegen.LoadManifest(s.modelName)
modelManifest, err := manifest.LoadManifest(s.modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Detect model architecture from config.json
configData, err := manifest.ReadConfig("config.json")
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
@@ -232,7 +232,7 @@ func (s *server) loadLLMModel() error {
switch {
case strings.Contains(archLower, "glm4moelite"):
m, err := glm4_moe_lite.LoadFromManifest(manifest)
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
if err != nil {
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
}

View File

@@ -1,4 +1,4 @@
package imagegen
package manifest
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package imagegen
package manifest
import (
"path/filepath"

View File

@@ -1,6 +1,6 @@
//go:build mlx
package imagegen
package manifest
import (
"fmt"
@@ -15,9 +15,9 @@ import (
type ManifestWeights struct {
manifest *ModelManifest
component string
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
// LoadWeightsFromManifest creates a weight loader from manifest storage.

View File

@@ -14,6 +14,8 @@ import (
"encoding/json"
"fmt"
"runtime"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// SupportedBackends lists the backends that support image generation.
@@ -41,8 +43,8 @@ func CheckPlatformSupport() error {
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
modelManifest, err := manifest.LoadManifest(modelName)
if err == nil && modelManifest.HasTensorLayers() {
return modelName
}
return ""
@@ -52,12 +54,12 @@ func ResolveModelName(modelName string) string {
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
manifest, err := LoadManifest(modelName)
modelManifest, err := manifest.LoadManifest(modelName)
if err != nil {
return ""
}
data, err := manifest.ReadConfig("model_index.json")
data, err := modelManifest.ReadConfig("model_index.json")
if err != nil {
return ""
}

View File

@@ -12,7 +12,7 @@ import (
"math"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -14,19 +14,19 @@ import (
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
}
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -15,21 +15,21 @@ import (
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -9,8 +9,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,11 +38,11 @@ type Config struct {
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
@@ -82,7 +82,7 @@ type MLAAttention struct {
// Absorbed MLA projections (derived from kv_b_proj)
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
EmbedQ *nn.MultiLinear `weight:"-"`
EmbedQ *nn.MultiLinear `weight:"-"`
UnembedOut *nn.MultiLinear `weight:"-"`
// Output projection
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := manifest.ReadConfig("config.json")
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
@@ -634,7 +634,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
cfg.Scale = computeScale(&cfg)
// Load weights from manifest blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
@@ -653,7 +653,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
tokData, err := modelManifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
@@ -664,12 +664,12 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -181,19 +181,19 @@ type TextEncoder struct {
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -7,8 +7,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,7 +38,7 @@ type TransformerConfig struct {
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
@@ -103,10 +103,9 @@ type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -132,11 +131,11 @@ type Attention struct {
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
@@ -288,13 +287,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -350,7 +349,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
OutDim int32 // computed from Output
}
// Forward computes final output
@@ -401,12 +400,12 @@ type Transformer struct {
}
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
@@ -417,7 +416,7 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Upsample conv
{
prev := x
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
@@ -643,16 +643,16 @@ type VAEDecoder struct {
}
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -8,8 +8,8 @@ import (
"fmt"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -18,14 +18,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build mlx
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
package mlxrunner
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen
import (
"context"
@@ -16,7 +16,6 @@ import (
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
@@ -98,7 +97,7 @@ func Execute(args []string) error {
// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
// Check for image generation model by looking at model_index.json
modelType := imagegen.DetectModelType(modelName)
modelType := DetectModelType(modelName)
if modelType != "" {
// Known image generation model types
switch modelType {

View File

@@ -1,6 +1,6 @@
//go:build !mlx
package mlxrunner
package imagegen
import "errors"

View File

@@ -1,4 +1,4 @@
package mlxrunner
package imagegen
import (
"bufio"
@@ -23,7 +23,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
@@ -46,7 +46,7 @@ type Server struct {
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := imagegen.CheckPlatformSupport(); err != nil {
if err := CheckPlatformSupport(); err != nil {
return nil, err
}
@@ -71,8 +71,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -107,8 +107,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
vramSize = uint64(manifest.TotalTensorSize())
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024

View File

@@ -1,9 +1,9 @@
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
//
// This package handles safetensors models created with `ollama create --experimental`,
// supporting both text generation (LLM) and image generation (diffusion) models
// through a single unified interface.
package mlxrunner
package imagegen
// Request is the request format for completion requests.
type Request struct {

View File

@@ -1,77 +0,0 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

View File

@@ -1,144 +0,0 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

View File

@@ -1,156 +0,0 @@
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }

View File

@@ -1,110 +0,0 @@
package kvcache
// import (
// "math"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// // caches we are wrapping
// caches []Cache
// // cache to be used for this layer
// curType int
// }
// func NewWrapperCache(caches ...Cache) *WrapperCache {
// return &WrapperCache{
// caches: caches,
// }
// }
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// for _, cache := range c.caches {
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// }
// }
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// for _, cache := range c.caches {
// cache.SetConfig(config)
// }
// }
// func (c *WrapperCache) Close() {
// for _, cache := range c.caches {
// cache.Close()
// }
// }
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// for i, cache := range c.caches {
// err := cache.StartForward(ctx, batch, reserve)
// if err != nil {
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// for j := i - 1; j >= 0; j-- {
// for k := range batch.Positions {
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// }
// }
// return err
// }
// }
// c.curType = 0
// return nil
// }
// func (c *WrapperCache) SetLayer(layer int) {
// for _, cache := range c.caches {
// cache.SetLayer(layer)
// }
// }
// func (c *WrapperCache) SetLayerType(layerType int) {
// c.curType = layerType
// }
// func (c *WrapperCache) UnderlyingCache() Cache {
// return c.caches[c.curType]
// }
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.caches[c.curType].Get(ctx)
// }
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// c.caches[c.curType].Put(ctx, key, value)
// }
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// for _, cache := range c.caches {
// cache.CopyPrefix(srcSeq, dstSeq, len)
// }
// }
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// for _, cache := range c.caches {
// if !cache.CanResume(seq, pos) {
// return false
// }
// }
// return true
// }
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// for _, cache := range c.caches {
// err := cache.Remove(seq, beginIndex, endIndex)
// if err != nil {
// return err
// }
// }
// return nil
// }

View File

@@ -1,433 +0,0 @@
package ml
import (
"fmt"
"log/slog"
"os"
"github.com/ollama/ollama/fs"
)
type Backend interface {
// Close frees all memory associated with this backend
// Close()
// Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model
// BackendMemory() BackendMemory
Config() fs.Config
Get(name string) Tensor
NewContext() Context
// NewContextSize(size int) Context
// Enumerate the devices available for inference via this backend
// BackendDevices() []DeviceInfo
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// AllocMemory causes the backend to allocate memory for the model. If
// false, this is only being used for discovering the required amount of
// memory and cannot load the model for running.
AllocMemory bool
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
// GPULayers is the set of layers to offload to GPUs
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
backends[name] = f
}
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
be := os.Getenv("OLLAMA_BACKEND")
if be == "" {
be = "mlx"
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
}
slog.Info("Loading new engine", "backend", be)
if backend, ok := backends[be]; ok {
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend")
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
FromFloats(s []float32, shape ...int) Tensor
FromInts(s []int32, shape ...int) Tensor
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
// SetBatchSize provides a hint on the batch size to optimize processing
// Uses heuristics if not set
// SetBatchSize(int)
Compute(...Tensor)
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available for
// for future inference.
// Reserve()
// MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
Input() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
// Load a tensor from "filename" safetensors file, and compare with the input tensor
// Returns error if the shape is inconsistent, or similarity measures are below 99%
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
type RoPEOptions struct {
Base *float32
Freqs Tensor
}
func WithRoPEBase(base float32) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Base = &base
}
}
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Freqs = freqs
}
}
type Tensor interface {
ToString() string
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
Dim(n int) int
Stride(n int) int
Shape() []int
DType() DType
// Cast(ctx Context, dtype DType) Tensor
// Bytes() []byte
Floats() []float32
Ints() []int32
// FromBytes([]byte)
// FromFloats([]float32)
// FromInts([]int32)
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
// Mul(ctx Context, t2 Tensor) Tensor
// Div(ctx Context, t2 Tensor) Tensor
Max(ctx Context, axes []int, keepDims bool) Tensor
Min(ctx Context, axes []int, keepDims bool) Tensor
Matmul(ctx Context, a2 Tensor) Tensor
// Mulmat(ctx Context, t2 Tensor) Tensor
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
// MulmatID(ctx Context, t2, ids Tensor) Tensor
// AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
// SumRows(ctx Context) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
// Sin(ctx Context) Tensor
// Cos(ctx Context) Tensor
// Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
// QuickGELU(ctx Context, up ...Tensor) Tensor
// SILU(ctx Context, up ...Tensor) Tensor
// RELU(ctx Context, up ...Tensor) Tensor
// Sigmoid(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
Transpose(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, allowColMajor bool) Tensor
// Pad(ctx Context, shape ...int) Tensor
// Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
// Repeat(ctx Context, dim, n int) Tensor
// Concat(ctx Context, t2 Tensor, dim int) Tensor
// Rows(ctx Context, t2 Tensor) Tensor
// TODO these probably aren't actually needed - false starts on trying to wire up cache
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
Copy(ctx Context, t2 Tensor) Tensor
// Duplicate(ctx Context) Tensor
// Slice(ctx Context, dim, low, high, step int) Tensor
// Chunk(ctx Context, dim int, size int) []Tensor
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
// TopK(ctx Context, k int) Tensor
// Argsort(ctx Context) Tensor
// Mean(ctx Context) Tensor
// Variance(ctx Context) Tensor
// Stddev(ctx Context) Tensor
// Sqr(ctx Context) Tensor
// Sqrt(ctx Context) Tensor
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
// type number interface {
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
// ~float32 | ~float64 |
// ~complex64 | ~complex128
// }
// func mul[T number](s ...T) T {
// p := T(1)
// for _, v := range s {
// p *= v
// }
// return p
// }
// type DumpOptions func(*dumpOptions)
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Precision = n
// }
// }
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Threshold = n
// }
// }
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.EdgeItems = n
// }
// }
// type dumpOptions struct {
// Precision, Threshold, EdgeItems int
// }
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
// for _, optsFunc := range optsFuncs {
// optsFunc(&opts)
// }
// if mul(t.Shape()...) <= opts.Threshold {
// opts.EdgeItems = math.MaxInt
// }
// switch t.DType() {
// case DTypeFloat32:
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeFloat16: // TODO other types...
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
// f32 = t.Copy(ctx, f32)
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeInt32:
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
// return strconv.FormatInt(int64(i), 10)
// })
// default:
// return "<unsupported>"
// }
// }
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
// if t.Bytes() == nil {
// ctx.Compute(t)
// }
// s := make(S, mul(t.Shape()...))
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
// panic(err)
// }
// shape := t.Shape()
// slices.Reverse(shape)
// var sb strings.Builder
// var f func([]int, int)
// f = func(dims []int, stride int) {
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
// sb.WriteString("[")
// defer func() { sb.WriteString("]") }()
// for i := 0; i < dims[0]; i++ {
// if i >= items && i < dims[0]-items {
// sb.WriteString("..., ")
// // skip to next printable element
// skip := dims[0] - 2*items
// if len(dims) > 1 {
// stride += mul(append(dims[1:], skip)...)
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
// }
// i += skip - 1
// } else if len(dims) > 1 {
// f(dims[1:], stride)
// stride += mul(dims[1:]...)
// if i < dims[0]-1 {
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
// }
// } else {
// text := fn(s[stride+i])
// if len(text) > 0 && text[0] != '-' {
// sb.WriteString(" ")
// }
// sb.WriteString(text)
// if i < dims[0]-1 {
// sb.WriteString(", ")
// }
// }
// }
// }
// f(shape, 0)
// return sb.String()
// }
type DType int
const (
DTypeBool DType = iota
DTypeUint8
DTypeUint16
DTypeUint32
DTypeUint64
DTypeInt8
DTypeInt16
DTypeInt32
DTypeInt64
DTypeFloat16
DTypeFloat32
DTypeFloat64
DTypeBfloat16
DTypeComplex64
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)

View File

@@ -1,3 +0,0 @@
package backend
// _ "github.com/ollama/ollama/x/ml/backend/mlx"

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,92 +0,0 @@
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
#include "mlx_dynamic.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
#define LOAD_LIB(path) LoadLibraryA(path)
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
#define CLOSE_LIB(handle) FreeLibrary(handle)
#define LIB_ERROR() "LoadLibrary failed"
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define GET_SYMBOL(handle, name) dlsym(handle, name)
#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
#ifdef __APPLE__
static const char* LIB_NAMES[] = {
"libmlxc.dylib",
"@loader_path/../build/lib/ollama/libmlxc.dylib",
"@executable_path/../build/lib/ollama/libmlxc.dylib",
"build/lib/ollama/libmlxc.dylib",
"../build/lib/ollama/libmlxc.dylib",
NULL
};
#else
static const char* LIB_NAMES[] = {
"libmlxc.so",
"$ORIGIN/../build/lib/ollama/libmlxc.so",
"build/lib/ollama/libmlxc.so",
"../build/lib/ollama/libmlxc.so",
NULL
};
#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
// Initialize MLX dynamic library
// Returns 0 on success, -1 on failure
// On failure, call mlx_dynamic_error() to get error message
int mlx_dynamic_init(void) {
if (mlx_initialized) {
return 0; // Already initialized
}
// Try each possible library path
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
if (mlx_handle != NULL) {
mlx_initialized = 1;
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Successfully loaded %s", LIB_NAMES[i]);
return 0;
}
}
// Failed to load library
const char* err = LIB_ERROR();
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Failed to load libmlxc library. %s",
err ? err : "Unknown error");
return -1;
}
// Get the last error message
const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void) {
return mlx_initialized;
}
// Cleanup (optional, called at program exit)
void mlx_dynamic_cleanup(void) {
if (mlx_handle != NULL) {
CLOSE_LIB(mlx_handle);
mlx_handle = NULL;
mlx_initialized = 0;
}
}

View File

@@ -1,26 +0,0 @@
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef __cplusplus
extern "C" {
#endif
// Initialize the MLX dynamic library
// Returns 0 on success, -1 on failure
int mlx_dynamic_init(void);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void);
// Cleanup resources (optional, for clean shutdown)
void mlx_dynamic_cleanup(void);
#ifdef __cplusplus
}
#endif
#endif // MLX_DYNAMIC_H

View File

@@ -1,314 +0,0 @@
//go:build mlx
package mlx
import (
"log/slog"
"os"
"reflect"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
_ "github.com/ollama/ollama/x/model/models/gemma3"
)
func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
}
func TestLoadModel(t *testing.T) {
dir := "/Users/daniel/Models/gemma-3-4b-it/"
b := &Backend{}
err := b.LoadSafeTensors(dir)
if err != nil {
t.Fatalf("load failed: %s", err)
}
}
func TestFromInts(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []int32{1, 2, 3, 4, 5, 6}
a := c.FromInts(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
}
func TestFromFloats(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []float32{1, 2, 3, 4, 5, 6}
a := c.FromFloats(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
res := a.Floats()
if !reflect.DeepEqual(res, data) {
t.Fatalf("incorrect results: %v", res)
}
}
func TestAdd(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
t3 := t1.Add(c, t2)
c.Compute(t3, exp)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp.Floats()) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestReshapeTranspose(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
c.Compute(t1)
t1f := t1.Floats()
exp := []float32{
0, 4, 8,
1, 5, 9,
2, 6, 10,
3, 7, 11,
12, 16, 20,
13, 17, 21,
14, 18, 22,
15, 19, 23,
}
if !reflect.DeepEqual(t1f, exp) {
t.Fatalf("incorrect results: %v", t1f)
}
}
func prod(vals ...int) int {
r := 1
for _, v := range vals {
r *= v
}
return r
}
func TestMatmul(t *testing.T) {
// TODO create scenarios...
b := &Backend{}
c := b.NewContext()
defer c.Close()
s1 := []int{1, 3, 2, 4}
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
s2 := []int{4, 2}
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
t3 := t1.Matmul(c, t2)
exp := []float32{
28, 34,
76, 98,
124, 162,
172, 226,
220, 290,
268, 354,
}
c.Compute(t3)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestRows(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
outputs := c.Zeros(ml.DTypeInt32, 1)
t2 := t1.TakeAxes(c, outputs, 1)
c.Forward(t1, t2).Compute(t1, t2)
t.Log(t1.ToString())
t.Log(t2.ToString())
f := t2.Floats()
t.Logf("Result: %v", f)
}
func TestCaching(t *testing.T) {
// Validate the caching algorithm
b := &Backend{}
c := b.NewContext()
defer c.Close()
batchSize := 3
headDim := 4
numKVHeads := 2
// Make cache twice the size of one test batch
cells := batchSize * 2
cellSize := numKVHeads * headDim
shape := []int{1, numKVHeads, batchSize, headDim}
stop := float32(1)
for _, x := range shape {
stop *= float32(x)
}
// Create the cache
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
// Input tensor
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
// Reshape to copy into the cache
/*
From MLX python/src/indexing.cpp mlx_scatter_args_array
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
*/
numRows := 3
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
// Simulate cells 1,3,5 are available
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
cache.Scatter(c, indicies, up, axis)
c.Forward(cache)
// Cache should contain the data now
t.Log("Cache after put\n" + cache.ToString())
// Retrieve cache content and verify it matches
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
t1f := t1.Floats()
outf := out.Floats()
if !reflect.DeepEqual(t1f, outf) {
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
}
}
func TestGemma3(t *testing.T) {
// Why is the sky blue
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
limit := 50
// TODO generalize this
dir := "/Users/daniel/Models/gemma-3-4b-it/"
m, err := model.New(dir, ml.BackendParams{})
if err != nil {
t.Fatalf("unable to load model: %s", err)
}
b := m.Backend()
ctx := b.NewContext()
defer ctx.Close()
batch := input.Batch{
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
Positions: make([]int32, len(inputs)),
Sequences: make([]int, len(inputs)),
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
Offset: 0,
}
for i := range len(inputs) {
batch.Positions[i] = int32(i)
}
offset := len(inputs)
cache := m.Config().Cache
if cache != nil {
numSlots := 1
batchSize := 512
numCtx := 4096
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
err := cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
}
opts := api.DefaultOptions()
var grammar *sample.GrammarSampler
sampler := sample.NewSampler(
opts.Temperature,
opts.TopK,
opts.TopP,
opts.MinP,
opts.Seed,
grammar,
)
t.Log("Starting Forward pass loop")
pendingResponses := []string{}
for {
out, err := m.Forward(ctx, batch)
if err != nil {
t.Fatalf("failed forward pass: %s", err)
}
ctx.Forward(out)
outputs := out.Floats()
t.Logf("finished forward pass! length:%d", len(outputs))
// sample a token
logits := outputs
token, err := sampler.Sample(logits)
if err != nil {
t.Fatalf("unable to sample token: %s", err)
}
t.Logf("Sampled token: %v", token)
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
t.Log("hit EOS")
break
}
piece, err := m.(model.TextProcessor).Decode([]int32{token})
if err != nil {
t.Fatalf("unable to decode token: %s", err)
}
pendingResponses = append(pendingResponses, piece)
sequence := strings.Join(pendingResponses, "")
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
t.Logf("hit stop token: %v", stop)
break
}
t.Logf("RESULTS: %s", sequence)
batch = input.Batch{
Inputs: ctx.FromInts([]int32{token}, 1, 1),
Positions: make([]int32, 1),
Sequences: make([]int, 1),
Outputs: ctx.FromInts([]int32{0}, 1),
Offset: offset,
}
offset++
batch.Positions[0] = 0
err = cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
if offset > limit {
break
}
}
}

View File

@@ -1,335 +0,0 @@
//go:build mlx
package mlx
/*
#include <stdio.h>
#include <string.h>
#include "mlx/c/array.h"
#include "mlx/c/ops.h"
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
void unpack_32_4(uint8_t* data, int8_t* dst) {
memset(dst, 0, 16);
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
}
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
}
dst[8 + j / 2] += x;
}
}
// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = -8 * scales[i];
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = *((float16_t*)(data) + 1);
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t weights_per_block = 32;
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
uint8_t* block_data = data + i * bytes_per_block;
scales[i] = *((float16_t*)block_data);
biases[i] = -128 * scales[i];
for (int64_t j = 0; j < weights_per_block; ++j) {
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
// Original data is in int8_t, so we add a bias of -128 and invert the
// first bit.
x ^= 1 << 7;
weights[i * weights_per_block + j] = x;
}
}
}
// Drived from ggml-quants.c
#define QK_K 256
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
uint16_t d; // super-block scale
} block_q6_K;
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
const int64_t nb = k / QK_K;
block_q6_K *x = (block_q6_K *)vx;
float16_t* y = (float16_t *)vy;
for (int i = 0; i < nb; i++) {
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l + 0] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
union {
struct {
uint16_t d; // super-block scale for quantized scales
uint16_t dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR_S;
uint16_t dm;
} GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
} else {
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
block_q4_K *x = (block_q4_K *)vx;
float16_t* y = (float16_t *)vy;
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
float16_t min = 0.0;
memcpy(&min, &x[i].dmin, sizeof(d));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float16_t d1 = d * sc; const float16_t m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float16_t d2 = d * sc; const float16_t m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
}
}
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/x448/float16"
)
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
shape := append([]C.int{}, final_shape...)
var weights_per_byte C.int
if dtype == 2 || dtype == 3 {
weights_per_byte = 2
} else if dtype == 8 {
weights_per_byte = 1
} else {
return r, fmt.Errorf("unsupported tensor type %d", dtype)
}
weights_per_block := C.int(32)
if shape[len(shape)-1]%weights_per_block != 0 {
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
}
weights_shape := append([]C.int{}, shape...)
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
for i := range weights_shape {
w_nbytes *= weights_shape[i]
}
w_data := make([]byte, w_nbytes)
cbytes := C.CBytes(w_data)
defer C.free(cbytes)
weights := C.mlx_array_new_data(
cbytes,
&weights_shape[0],
C.int(len(weights_shape)),
C.MLX_UINT32,
)
// For scales and bias
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
for i := range shape {
sb_nbytes *= shape[i]
}
s_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(s_data)
defer C.free(cbytes)
scales := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
b_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(b_data)
defer C.free(cbytes)
biases := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
var bits C.int
switch dtype {
case 2:
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 3:
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 8:
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 8
}
groupSize := C.mlx_optional_int{value: 32, has_value: true}
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
C.mlx_dequantize(
&r,
weights,
scales,
biases,
groupSize,
bitsOpt,
nil, // TODO mode
dtypeOpt,
stream,
)
C.mlx_array_free(weights)
C.mlx_array_free(scales)
C.mlx_array_free(biases)
return r, nil
}
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
size := 1
for _, d := range shape {
size *= int(d)
}
fdata := make([]float16.Float16, size)
switch dtype {
case 14:
C.dequant_row_q6_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
case 12:
C.dequant_row_q4_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
default:
return r, fmt.Errorf("unsupported K quant")
}
r = C.mlx_array_new_data(
unsafe.Pointer(&fdata[0]),
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
return r, nil
}

View File

@@ -1,643 +0,0 @@
package ml
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"hash/maphash"
"io"
"log/slog"
"math"
"net/http"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/logutil"
)
// GPULayers is a set of layers to be allocated on a single GPU
type GPULayers struct {
DeviceID
// Layers is a set of layer indicies to load
Layers []int
}
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
func (g GPULayers) FirstLayer() int {
if len(g.Layers) == 0 {
return math.MaxInt
}
first := g.Layers[0]
for i := 1; i < len(g.Layers); i++ {
if g.Layers[i] < first {
first = g.Layers[i]
}
}
return first
}
func (g GPULayers) String() string {
if len(g.Layers) == 0 {
return ""
}
slices.Sort(g.Layers)
contiguous := true
base := g.Layers[0]
for i := range g.Layers {
if g.Layers[i] != base+i {
contiguous = false
break
}
}
if contiguous {
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
} else {
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
}
}
// GPULayersList is a set of layer allocations across multiple GPUs
type GPULayersList []GPULayers
func (l GPULayersList) Len() int { return len(l) }
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
// Sort by the ordering of the layers offloaded
func (l GPULayersList) Less(i, j int) bool {
li := l[i].FirstLayer()
lj := l[j].FirstLayer()
return li < lj
}
func (l GPULayersList) String() string {
if l.Sum() > 0 {
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
} else {
return fmt.Sprintf("%v", []GPULayers(l))
}
}
// Sum is the total number of layers assigned across all GPUs
func (l GPULayersList) Sum() int {
var sum int
for _, g := range l {
sum += len(g.Layers)
}
return sum
}
var h maphash.Hash
// Hash is an identifier of this layer assignment
func (l GPULayersList) Hash() uint64 {
h.Reset()
for _, g := range l {
if len(g.Layers) > 0 {
h.WriteString(g.ID + g.Library)
for _, l := range g.Layers {
binary.Write(&h, binary.NativeEndian, int64(l))
}
}
}
return h.Sum64()
}
// ErrNoMem is returned when panicing due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
BackendMemory
}
func (e ErrNoMem) Error() string {
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}
// Minimal unique device identification
type DeviceID struct {
// ID is an identifier for the device for matching with system
// management libraries. The ID is only unique for other devices
// using the same Library.
// This ID represents a "post filtered" view of the enumerated devices
// if the ID is numeric
ID string `json:"id"`
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
Library string `json:"backend,omitempty"`
}
// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string
// Weights is the per-layer memory needed for the model weights.
Weights []uint64
// Cache is the per-layer memory needed for the KV cache.
Cache []uint64
// Graph is the size of the compute graph. It is not per-layer.
Graph uint64
}
func sumMemory(mem []uint64) uint64 {
var sum uint64
for _, m := range mem {
sum += m
}
return sum
}
// Size returns the total size of the memory required by this device
func (m DeviceMemory) Size() uint64 {
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
}
func memoryPresent(mem []uint64) bool {
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
}
func (m DeviceMemory) LogValue() slog.Value {
var attrs []slog.Attr
if memoryPresent(m.Weights) {
attrs = append(attrs, slog.Any("Weights", m.Weights))
}
if memoryPresent(m.Cache) {
attrs = append(attrs, slog.Any("Cache", m.Cache))
}
if m.Graph != 0 {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
}
return slog.GroupValue(attrs...)
}
// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputWeights are always located on the CPU and cannot be moved
InputWeights uint64
// CPU model components are located in system memory. This does not
// include unified memory allocated through the GPU.
CPU DeviceMemory
// GPU model components are located on one or more GPUs.
GPUs []DeviceMemory
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
for _, g := range m.GPUs {
attrs = append(attrs, slog.Any(g.Name, g))
}
return slog.GroupValue(attrs...)
}
// Log prints a high level summary of the memory
func (m BackendMemory) Log(level slog.Level) {
var total uint64
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := sumMemory(m.CPU.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := gpu.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.CPU.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
if total > 0 {
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
}
}
type DeviceInfo struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string `json:"name"`
// Description is the longer user-friendly identification of the device
Description string `json:"description"`
// FilterID is populated with the unfiltered device ID if a numeric ID is used
// so the device can be included.
FilterID string `json:"filter_id,omitempty"`
// Integrated is set true for integrated GPUs, false for Discrete GPUs
Integrated bool `json:"integration,omitempty"`
// PCIID is the bus, device and domain ID of the device for deduplication
// when discovered by multiple backends
PCIID string `json:"pci_id,omitempty"`
// TotalMemory is the total amount of memory the device can use for loading models
TotalMemory uint64 `json:"total_memory"`
// FreeMemory is the amount of memory currently available on the device for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// ComputeMajor is the major version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMajor int
// ComputeMinor is the minor version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMinor int
// Driver Information
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
// Where backends were loaded from
LibraryPath []string
}
type SystemInfo struct {
// ThreadCount is the optimal number of threads to use for inference
ThreadCount int `json:"threads,omitempty"`
// TotalMemory is the total amount of system memory
TotalMemory uint64 `json:"total_memory,omitempty"`
// FreeMemory is the amount of memory currently available on the system for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// FreeSwap is the amount of system swap space reported as available
FreeSwap uint64 `json:"free_swap,omitempty"`
}
func (d DeviceInfo) Compute() string {
// AMD gfx is encoded into the major minor in hex form
if strings.EqualFold(d.Library, "ROCm") {
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
}
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
}
func (d DeviceInfo) Driver() string {
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
}
// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
if d.Library == "Metal" {
return 512 * format.MebiByte
}
return 457 * format.MebiByte
}
// Sort by Free Space.
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
type ByFreeMemory []DeviceInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
if a[i].Integrated && !a[j].Integrated {
return true
} else if !a[i].Integrated && a[j].Integrated {
return false
}
return a[i].FreeMemory < a[j].FreeMemory
}
// ByPerformance groups devices by similar speed
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
scores := []bool{}
for _, info := range l {
found := false
requested := info.Integrated
for i, score := range scores {
if score == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
scores = append(scores, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath}
for _, gpu := range l {
for _, dir := range gpu.LibraryPath {
needed := true
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
gpuLibs = append(gpuLibs, dir)
}
}
}
return gpuLibs
}
type DeviceComparison int
const (
UniqueDevice DeviceComparison = iota
SameBackendDevice // The device is the same, and the library/backend is the same
DuplicateDevice // The same physical device but different library/backend (overlapping device)
)
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID {
return UniqueDevice
}
// If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID {
return UniqueDevice
}
if a.Library == b.Library {
return SameBackendDevice
}
return DuplicateDevice
}
// For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib {
return false
}
aLibSplit := strings.SplitN(aLib, "_", 2)
bLibSplit := strings.SplitN(bLib, "_", 2)
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
return false
}
if aLibSplit[0] != bLibSplit[0] {
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
return false
}
if aLibSplit[1] == bLibSplit[1] {
return false
}
cmp := []string{aLibSplit[1], bLibSplit[1]}
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
return cmp[0] == bLibSplit[1]
}
// For each GPU, check if it does NOT support flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false
}
}
return true
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
if len(l) == 0 {
return nil
}
env := map[string]string{}
for _, d := range l {
d.updateVisibleDevicesEnv(env, mustFilter)
}
return env
}
// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
}
// Set the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}
// PreferredLibrary returns true if this library is preferred over the other input
// library
// Used to filter out Vulkan in favor of CUDA or ROCm
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
// TODO in the future if we find Vulkan is better than ROCm on some devices
// that implementation can live here.
if d.Library == "CUDA" || d.Library == "ROCm" {
return true
}
return false
}
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
var envVar string
switch d.Library {
case "ROCm":
// ROCm must be filtered as it can crash the runner on unsupported devices
envVar = "ROCR_VISIBLE_DEVICES"
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
}
case "CUDA":
if !mustFilter {
// By default we try to avoid filtering CUDA devices because ROCm also
// looks at the CUDA env var, and gets confused in mixed vendor environments.
return
}
envVar = "CUDA_VISIBLE_DEVICES"
default:
// Vulkan is not filtered via env var, but via scheduling decisions
return
}
v, existing := env[envVar]
if existing {
v = v + ","
}
if d.FilterID != "" {
v = v + d.FilterID
} else {
v = v + d.ID
}
env[envVar] = v
}
type BaseRunner interface {
// GetPort returns the localhost port number the runner is running on
GetPort() int
// HasExited indicates if the runner is no longer running. This can be used during
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
HasExited() bool
}
type RunnerDiscovery interface {
BaseRunner
// GetDeviceInfos will perform a query of the underlying device libraries
// for device identification and free VRAM information
// During bootstrap scenarios, this routine may take seconds to complete
GetDeviceInfos(ctx context.Context) []DeviceInfo
}
type FilteredRunnerDiscovery interface {
RunnerDiscovery
// GetActiveDeviceIDs returns the filtered set of devices actively in
// use by this runner for running models. If the runner is a bootstrap runner, no devices
// will be active yet so no device IDs are returned.
// This routine will not query the underlying device and will return immediately
GetActiveDeviceIDs() []DeviceID
}
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
var moreDevices []DeviceInfo
port := runner.GetPort()
tick := time.Tick(10 * time.Millisecond)
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout")
case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
// slog.Warn("failed to send request", "error", err)
if runner.HasExited() {
return nil, fmt.Errorf("runner crashed")
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
slog.Warn("failed to read response", "error", err)
continue
}
if resp.StatusCode != 200 {
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
return nil, fmt.Errorf("runner error: %s", string(body))
}
if err := json.Unmarshal(body, &moreDevices); err != nil {
slog.Warn("unmarshal encode response", "error", err)
continue
}
return moreDevices, nil
}
}
}

View File

@@ -1,103 +0,0 @@
package nn
import (
"fmt"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
// panic("after cache get") //
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
// var mask ml.Tensor
if cache != nil {
key, value, _ = cache.Get(ctx)
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
// panic("after cache get") //
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if cache != nil {
// TODO what to do with vmla?
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
} else {
panic("else case not supported")
// TODO transpose shapes are wrong
// key = key.Transpose(ctx, 0, 2, 1, 3)
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
// kq := query.Matmul(ctx, key)
// kq = kq.Scale(ctx, scale)
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
// kq = kq.Softmax(ctx)
// kqv := kq.Matmul(ctx, value)
// if vmla != nil {
// kqv = kqv.Matmul(ctx, vmla)
// }
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
}
}

View File

@@ -1,30 +0,0 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Conv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
if m.Bias != nil {
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
}
return t
}
type Conv3D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}

View File

@@ -1,11 +0,0 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Embedding struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
return m.Weight.TakeAxes(ctx, hiddenState, 0)
}

View File

@@ -1,32 +0,0 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Linear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
panic("not yet ported")
// t = m.Weight.MulmatID(ctx, t, indices)
// if m.Bias != nil {
// t = t.AddID(ctx, m.Bias, indices)
// }
// return t
}

View File

@@ -1,29 +0,0 @@
package nn
import (
"github.com/ollama/ollama/x/ml"
)
type LayerNorm struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}
type RMSNorm struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
// slog.Info("RMSNorm", "eps", eps)
// fmt.Fprintln(os.Stderr, t.ToString())
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
// TODO this is probably model specific, not generalized...
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
return t.RMSNorm(ctx, w, eps)
}

View File

@@ -1,41 +0,0 @@
package pooling
import (
"github.com/ollama/ollama/x/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
// case TypeMean:
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
// case TypeCLS:
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
// case TypeLast:
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
default:
panic("unknown pooling type")
}
}

View File

@@ -1,72 +0,0 @@
package rope
import "github.com/ollama/ollama/x/ml"
// Options contains optional parameters for RoPE function
type Options struct {
Type int
Factors ml.Tensor
// YaRN options
YaRN struct {
OriginalContextLength int
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// MRoPE options
MRoPE struct {
Sections []int
}
}
// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
return func(opts *Options) {
opts.Type = 2
}
}
// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
return func(opts *Options) {
if factors != nil {
opts.Factors = factors
}
}
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.YaRN.OriginalContextLength = n
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.AttentionFactor = attentionFactor
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
opts.MRoPE.Sections = sections
}
}
func WithInterleaveMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1<<3 | 1<<5
opts.MRoPE.Sections = sections
}
}

View File

@@ -1,56 +0,0 @@
package ml
import (
"os"
"path/filepath"
"runtime"
)
// LibPath is a path to lookup dynamic libraries
// in development it's usually 'build/lib/ollama'
// in distribution builds it's 'lib/ollama' on Windows
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {
return ""
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string
switch runtime.GOOS {
case "windows":
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
case "linux":
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
case "darwin":
libPath = filepath.Dir(exe)
}
cwd, err := os.Getwd()
if err != nil {
return ""
}
paths := []string{
libPath,
// build paths for development
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
filepath.Join(cwd, "build", "lib", "ollama"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
return filepath.Dir(exe)
}()

Some files were not shown because too many files have changed in this diff Show More