Compare commits

..

2 Commits

Author SHA1 Message Date
Michael Yang
f1373193dc move tokenizers to separate package (#13825) 2026-02-05 17:44:11 -08:00
Parth Sareen
8a4b77f9da cmd: set context limits for cloud models in opencode (#14107) 2026-02-05 16:36:46 -08:00
83 changed files with 445 additions and 6721 deletions

View File

@@ -482,6 +482,8 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")

View File

@@ -502,6 +502,28 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
}
}
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify loadIntegration returns the saved models
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatal(err)
}
if len(saved.Models) == 0 {
t.Fatal("expected saved models")
}
if saved.Models[0] != "llama3.2" {
t.Errorf("expected llama3.2, got %s", saved.Models[0])
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}

View File

@@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/json"
"fmt"
"maps"
@@ -10,12 +11,52 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
@@ -113,6 +154,8 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -122,12 +165,29 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
models[model] = map[string]any{
entry := map[string]any{
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models

View File

@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -495,6 +496,165 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
if entry["limit"] != nil {
t.Errorf("local model should not have limit set, got %v", entry["limit"])
}
}
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
// Set up a model with a user-configured limit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"llama3.2": {
"name": "llama3.2",
"_launch": true,
"limit": {"context": 8192, "output": 4096}
}
}
}
}
}`), 0o644)
// Re-edit should preserve the user's limit (not delete it)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("user-configured limit was removed")
}
if limit["context"] != float64(8192) {
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
}
if limit["output"] != float64(4096) {
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
}
}
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
// Verify that when a cloud model entry has limits set (as Edit would do),
// the structure matches what opencode expects and re-edit preserves them.
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
expected := cloudModelLimits["glm-4.7"]
// Simulate a cloud model that already has the limit set by a previous Edit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
"provider": {
"ollama": {
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud",
"_launch": true,
"limit": {"context": %d, "output": %d}
}
}
}
}
}`, expected.Context, expected.Output)), 0o644)
// Re-edit should preserve the cloud model limit
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was removed on re-edit")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -116,7 +117,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tok tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tok, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tok == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tok == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tok != nil,
modelPath,
gpuLibs,
status,
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}

View File

@@ -1,272 +0,0 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,28 +45,24 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,53 +0,0 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tok.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
&model.Vocabulary{
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"cmp"
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
var _ Tokenizer = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
for _, p := range pretokenizer {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
var _ Tokenizer = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// For tokenizer that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"testing"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements TextProcessor.
// Decode implements Tokenizer.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements TextProcessor.
// Encode implements Tokenizer.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements TextProcessor.
// Is implements Tokenizer.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
// Vocabulary implements Tokenizer.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"slices"

View File

@@ -1,77 +0,0 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

View File

@@ -1,144 +0,0 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

View File

@@ -1,156 +0,0 @@
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }

View File

@@ -1,110 +0,0 @@
package kvcache
// import (
// "math"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// // caches we are wrapping
// caches []Cache
// // cache to be used for this layer
// curType int
// }
// func NewWrapperCache(caches ...Cache) *WrapperCache {
// return &WrapperCache{
// caches: caches,
// }
// }
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// for _, cache := range c.caches {
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// }
// }
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// for _, cache := range c.caches {
// cache.SetConfig(config)
// }
// }
// func (c *WrapperCache) Close() {
// for _, cache := range c.caches {
// cache.Close()
// }
// }
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// for i, cache := range c.caches {
// err := cache.StartForward(ctx, batch, reserve)
// if err != nil {
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// for j := i - 1; j >= 0; j-- {
// for k := range batch.Positions {
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// }
// }
// return err
// }
// }
// c.curType = 0
// return nil
// }
// func (c *WrapperCache) SetLayer(layer int) {
// for _, cache := range c.caches {
// cache.SetLayer(layer)
// }
// }
// func (c *WrapperCache) SetLayerType(layerType int) {
// c.curType = layerType
// }
// func (c *WrapperCache) UnderlyingCache() Cache {
// return c.caches[c.curType]
// }
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.caches[c.curType].Get(ctx)
// }
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// c.caches[c.curType].Put(ctx, key, value)
// }
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// for _, cache := range c.caches {
// cache.CopyPrefix(srcSeq, dstSeq, len)
// }
// }
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// for _, cache := range c.caches {
// if !cache.CanResume(seq, pos) {
// return false
// }
// }
// return true
// }
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// for _, cache := range c.caches {
// err := cache.Remove(seq, beginIndex, endIndex)
// if err != nil {
// return err
// }
// }
// return nil
// }

View File

@@ -1,433 +0,0 @@
package ml
import (
"fmt"
"log/slog"
"os"
"github.com/ollama/ollama/fs"
)
type Backend interface {
// Close frees all memory associated with this backend
// Close()
// Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model
// BackendMemory() BackendMemory
Config() fs.Config
Get(name string) Tensor
NewContext() Context
// NewContextSize(size int) Context
// Enumerate the devices available for inference via this backend
// BackendDevices() []DeviceInfo
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// AllocMemory causes the backend to allocate memory for the model. If
// false, this is only being used for discovering the required amount of
// memory and cannot load the model for running.
AllocMemory bool
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
// GPULayers is the set of layers to offload to GPUs
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
backends[name] = f
}
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
be := os.Getenv("OLLAMA_BACKEND")
if be == "" {
be = "mlx"
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
}
slog.Info("Loading new engine", "backend", be)
if backend, ok := backends[be]; ok {
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend")
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
FromFloats(s []float32, shape ...int) Tensor
FromInts(s []int32, shape ...int) Tensor
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
// SetBatchSize provides a hint on the batch size to optimize processing
// Uses heuristics if not set
// SetBatchSize(int)
Compute(...Tensor)
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available for
// for future inference.
// Reserve()
// MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
Input() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
// Load a tensor from "filename" safetensors file, and compare with the input tensor
// Returns error if the shape is inconsistent, or similarity measures are below 99%
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
type RoPEOptions struct {
Base *float32
Freqs Tensor
}
func WithRoPEBase(base float32) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Base = &base
}
}
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Freqs = freqs
}
}
type Tensor interface {
ToString() string
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
Dim(n int) int
Stride(n int) int
Shape() []int
DType() DType
// Cast(ctx Context, dtype DType) Tensor
// Bytes() []byte
Floats() []float32
Ints() []int32
// FromBytes([]byte)
// FromFloats([]float32)
// FromInts([]int32)
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
// Mul(ctx Context, t2 Tensor) Tensor
// Div(ctx Context, t2 Tensor) Tensor
Max(ctx Context, axes []int, keepDims bool) Tensor
Min(ctx Context, axes []int, keepDims bool) Tensor
Matmul(ctx Context, a2 Tensor) Tensor
// Mulmat(ctx Context, t2 Tensor) Tensor
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
// MulmatID(ctx Context, t2, ids Tensor) Tensor
// AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
// SumRows(ctx Context) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
// Sin(ctx Context) Tensor
// Cos(ctx Context) Tensor
// Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
// QuickGELU(ctx Context, up ...Tensor) Tensor
// SILU(ctx Context, up ...Tensor) Tensor
// RELU(ctx Context, up ...Tensor) Tensor
// Sigmoid(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
Transpose(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, allowColMajor bool) Tensor
// Pad(ctx Context, shape ...int) Tensor
// Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
// Repeat(ctx Context, dim, n int) Tensor
// Concat(ctx Context, t2 Tensor, dim int) Tensor
// Rows(ctx Context, t2 Tensor) Tensor
// TODO these probably aren't actually needed - false starts on trying to wire up cache
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
Copy(ctx Context, t2 Tensor) Tensor
// Duplicate(ctx Context) Tensor
// Slice(ctx Context, dim, low, high, step int) Tensor
// Chunk(ctx Context, dim int, size int) []Tensor
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
// TopK(ctx Context, k int) Tensor
// Argsort(ctx Context) Tensor
// Mean(ctx Context) Tensor
// Variance(ctx Context) Tensor
// Stddev(ctx Context) Tensor
// Sqr(ctx Context) Tensor
// Sqrt(ctx Context) Tensor
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
// type number interface {
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
// ~float32 | ~float64 |
// ~complex64 | ~complex128
// }
// func mul[T number](s ...T) T {
// p := T(1)
// for _, v := range s {
// p *= v
// }
// return p
// }
// type DumpOptions func(*dumpOptions)
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Precision = n
// }
// }
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Threshold = n
// }
// }
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.EdgeItems = n
// }
// }
// type dumpOptions struct {
// Precision, Threshold, EdgeItems int
// }
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
// for _, optsFunc := range optsFuncs {
// optsFunc(&opts)
// }
// if mul(t.Shape()...) <= opts.Threshold {
// opts.EdgeItems = math.MaxInt
// }
// switch t.DType() {
// case DTypeFloat32:
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeFloat16: // TODO other types...
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
// f32 = t.Copy(ctx, f32)
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeInt32:
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
// return strconv.FormatInt(int64(i), 10)
// })
// default:
// return "<unsupported>"
// }
// }
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
// if t.Bytes() == nil {
// ctx.Compute(t)
// }
// s := make(S, mul(t.Shape()...))
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
// panic(err)
// }
// shape := t.Shape()
// slices.Reverse(shape)
// var sb strings.Builder
// var f func([]int, int)
// f = func(dims []int, stride int) {
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
// sb.WriteString("[")
// defer func() { sb.WriteString("]") }()
// for i := 0; i < dims[0]; i++ {
// if i >= items && i < dims[0]-items {
// sb.WriteString("..., ")
// // skip to next printable element
// skip := dims[0] - 2*items
// if len(dims) > 1 {
// stride += mul(append(dims[1:], skip)...)
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
// }
// i += skip - 1
// } else if len(dims) > 1 {
// f(dims[1:], stride)
// stride += mul(dims[1:]...)
// if i < dims[0]-1 {
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
// }
// } else {
// text := fn(s[stride+i])
// if len(text) > 0 && text[0] != '-' {
// sb.WriteString(" ")
// }
// sb.WriteString(text)
// if i < dims[0]-1 {
// sb.WriteString(", ")
// }
// }
// }
// }
// f(shape, 0)
// return sb.String()
// }
type DType int
const (
DTypeBool DType = iota
DTypeUint8
DTypeUint16
DTypeUint32
DTypeUint64
DTypeInt8
DTypeInt16
DTypeInt32
DTypeInt64
DTypeFloat16
DTypeFloat32
DTypeFloat64
DTypeBfloat16
DTypeComplex64
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)

View File

@@ -1,3 +0,0 @@
package backend
// _ "github.com/ollama/ollama/x/ml/backend/mlx"

View File

@@ -1,61 +0,0 @@
include(FetchContent)
# Read MLX version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)
function(set_target_output_directory _target)
if(TARGET ${_target})
set_target_properties(${_target} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
)
endif()
endfunction()
# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(NOT MLX_METAL_VERSION)
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
else()
# On Linux, disable Metal backend
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG})
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)
set_target_output_directory(mlxc)

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,92 +0,0 @@
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
#include "mlx_dynamic.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
#define LOAD_LIB(path) LoadLibraryA(path)
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
#define CLOSE_LIB(handle) FreeLibrary(handle)
#define LIB_ERROR() "LoadLibrary failed"
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define GET_SYMBOL(handle, name) dlsym(handle, name)
#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
#ifdef __APPLE__
static const char* LIB_NAMES[] = {
"libmlxc.dylib",
"@loader_path/../build/lib/ollama/libmlxc.dylib",
"@executable_path/../build/lib/ollama/libmlxc.dylib",
"build/lib/ollama/libmlxc.dylib",
"../build/lib/ollama/libmlxc.dylib",
NULL
};
#else
static const char* LIB_NAMES[] = {
"libmlxc.so",
"$ORIGIN/../build/lib/ollama/libmlxc.so",
"build/lib/ollama/libmlxc.so",
"../build/lib/ollama/libmlxc.so",
NULL
};
#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
// Initialize MLX dynamic library
// Returns 0 on success, -1 on failure
// On failure, call mlx_dynamic_error() to get error message
int mlx_dynamic_init(void) {
if (mlx_initialized) {
return 0; // Already initialized
}
// Try each possible library path
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
if (mlx_handle != NULL) {
mlx_initialized = 1;
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Successfully loaded %s", LIB_NAMES[i]);
return 0;
}
}
// Failed to load library
const char* err = LIB_ERROR();
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Failed to load libmlxc library. %s",
err ? err : "Unknown error");
return -1;
}
// Get the last error message
const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void) {
return mlx_initialized;
}
// Cleanup (optional, called at program exit)
void mlx_dynamic_cleanup(void) {
if (mlx_handle != NULL) {
CLOSE_LIB(mlx_handle);
mlx_handle = NULL;
mlx_initialized = 0;
}
}

View File

@@ -1,26 +0,0 @@
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef __cplusplus
extern "C" {
#endif
// Initialize the MLX dynamic library
// Returns 0 on success, -1 on failure
int mlx_dynamic_init(void);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void);
// Cleanup resources (optional, for clean shutdown)
void mlx_dynamic_cleanup(void);
#ifdef __cplusplus
}
#endif
#endif // MLX_DYNAMIC_H

View File

@@ -1,314 +0,0 @@
//go:build mlx
package mlx
import (
"log/slog"
"os"
"reflect"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
_ "github.com/ollama/ollama/x/model/models/gemma3"
)
func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
}
func TestLoadModel(t *testing.T) {
dir := "/Users/daniel/Models/gemma-3-4b-it/"
b := &Backend{}
err := b.LoadSafeTensors(dir)
if err != nil {
t.Fatalf("load failed: %s", err)
}
}
func TestFromInts(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []int32{1, 2, 3, 4, 5, 6}
a := c.FromInts(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
}
func TestFromFloats(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []float32{1, 2, 3, 4, 5, 6}
a := c.FromFloats(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
res := a.Floats()
if !reflect.DeepEqual(res, data) {
t.Fatalf("incorrect results: %v", res)
}
}
func TestAdd(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
t3 := t1.Add(c, t2)
c.Compute(t3, exp)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp.Floats()) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestReshapeTranspose(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
c.Compute(t1)
t1f := t1.Floats()
exp := []float32{
0, 4, 8,
1, 5, 9,
2, 6, 10,
3, 7, 11,
12, 16, 20,
13, 17, 21,
14, 18, 22,
15, 19, 23,
}
if !reflect.DeepEqual(t1f, exp) {
t.Fatalf("incorrect results: %v", t1f)
}
}
func prod(vals ...int) int {
r := 1
for _, v := range vals {
r *= v
}
return r
}
func TestMatmul(t *testing.T) {
// TODO create scenarios...
b := &Backend{}
c := b.NewContext()
defer c.Close()
s1 := []int{1, 3, 2, 4}
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
s2 := []int{4, 2}
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
t3 := t1.Matmul(c, t2)
exp := []float32{
28, 34,
76, 98,
124, 162,
172, 226,
220, 290,
268, 354,
}
c.Compute(t3)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestRows(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
outputs := c.Zeros(ml.DTypeInt32, 1)
t2 := t1.TakeAxes(c, outputs, 1)
c.Forward(t1, t2).Compute(t1, t2)
t.Log(t1.ToString())
t.Log(t2.ToString())
f := t2.Floats()
t.Logf("Result: %v", f)
}
func TestCaching(t *testing.T) {
// Validate the caching algorithm
b := &Backend{}
c := b.NewContext()
defer c.Close()
batchSize := 3
headDim := 4
numKVHeads := 2
// Make cache twice the size of one test batch
cells := batchSize * 2
cellSize := numKVHeads * headDim
shape := []int{1, numKVHeads, batchSize, headDim}
stop := float32(1)
for _, x := range shape {
stop *= float32(x)
}
// Create the cache
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
// Input tensor
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
// Reshape to copy into the cache
/*
From MLX python/src/indexing.cpp mlx_scatter_args_array
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
*/
numRows := 3
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
// Simulate cells 1,3,5 are available
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
cache.Scatter(c, indicies, up, axis)
c.Forward(cache)
// Cache should contain the data now
t.Log("Cache after put\n" + cache.ToString())
// Retrieve cache content and verify it matches
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
t1f := t1.Floats()
outf := out.Floats()
if !reflect.DeepEqual(t1f, outf) {
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
}
}
func TestGemma3(t *testing.T) {
// Why is the sky blue
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
limit := 50
// TODO generalize this
dir := "/Users/daniel/Models/gemma-3-4b-it/"
m, err := model.New(dir, ml.BackendParams{})
if err != nil {
t.Fatalf("unable to load model: %s", err)
}
b := m.Backend()
ctx := b.NewContext()
defer ctx.Close()
batch := input.Batch{
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
Positions: make([]int32, len(inputs)),
Sequences: make([]int, len(inputs)),
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
Offset: 0,
}
for i := range len(inputs) {
batch.Positions[i] = int32(i)
}
offset := len(inputs)
cache := m.Config().Cache
if cache != nil {
numSlots := 1
batchSize := 512
numCtx := 4096
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
err := cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
}
opts := api.DefaultOptions()
var grammar *sample.GrammarSampler
sampler := sample.NewSampler(
opts.Temperature,
opts.TopK,
opts.TopP,
opts.MinP,
opts.Seed,
grammar,
)
t.Log("Starting Forward pass loop")
pendingResponses := []string{}
for {
out, err := m.Forward(ctx, batch)
if err != nil {
t.Fatalf("failed forward pass: %s", err)
}
ctx.Forward(out)
outputs := out.Floats()
t.Logf("finished forward pass! length:%d", len(outputs))
// sample a token
logits := outputs
token, err := sampler.Sample(logits)
if err != nil {
t.Fatalf("unable to sample token: %s", err)
}
t.Logf("Sampled token: %v", token)
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
t.Log("hit EOS")
break
}
piece, err := m.(model.TextProcessor).Decode([]int32{token})
if err != nil {
t.Fatalf("unable to decode token: %s", err)
}
pendingResponses = append(pendingResponses, piece)
sequence := strings.Join(pendingResponses, "")
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
t.Logf("hit stop token: %v", stop)
break
}
t.Logf("RESULTS: %s", sequence)
batch = input.Batch{
Inputs: ctx.FromInts([]int32{token}, 1, 1),
Positions: make([]int32, 1),
Sequences: make([]int, 1),
Outputs: ctx.FromInts([]int32{0}, 1),
Offset: offset,
}
offset++
batch.Positions[0] = 0
err = cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
if offset > limit {
break
}
}
}

View File

@@ -1,335 +0,0 @@
//go:build mlx
package mlx
/*
#include <stdio.h>
#include <string.h>
#include "mlx/c/array.h"
#include "mlx/c/ops.h"
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
void unpack_32_4(uint8_t* data, int8_t* dst) {
memset(dst, 0, 16);
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
}
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
}
dst[8 + j / 2] += x;
}
}
// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = -8 * scales[i];
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = *((float16_t*)(data) + 1);
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t weights_per_block = 32;
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
uint8_t* block_data = data + i * bytes_per_block;
scales[i] = *((float16_t*)block_data);
biases[i] = -128 * scales[i];
for (int64_t j = 0; j < weights_per_block; ++j) {
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
// Original data is in int8_t, so we add a bias of -128 and invert the
// first bit.
x ^= 1 << 7;
weights[i * weights_per_block + j] = x;
}
}
}
// Drived from ggml-quants.c
#define QK_K 256
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
uint16_t d; // super-block scale
} block_q6_K;
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
const int64_t nb = k / QK_K;
block_q6_K *x = (block_q6_K *)vx;
float16_t* y = (float16_t *)vy;
for (int i = 0; i < nb; i++) {
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l + 0] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
union {
struct {
uint16_t d; // super-block scale for quantized scales
uint16_t dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR_S;
uint16_t dm;
} GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
} else {
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
block_q4_K *x = (block_q4_K *)vx;
float16_t* y = (float16_t *)vy;
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
float16_t min = 0.0;
memcpy(&min, &x[i].dmin, sizeof(d));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float16_t d1 = d * sc; const float16_t m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float16_t d2 = d * sc; const float16_t m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
}
}
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/x448/float16"
)
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
shape := append([]C.int{}, final_shape...)
var weights_per_byte C.int
if dtype == 2 || dtype == 3 {
weights_per_byte = 2
} else if dtype == 8 {
weights_per_byte = 1
} else {
return r, fmt.Errorf("unsupported tensor type %d", dtype)
}
weights_per_block := C.int(32)
if shape[len(shape)-1]%weights_per_block != 0 {
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
}
weights_shape := append([]C.int{}, shape...)
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
for i := range weights_shape {
w_nbytes *= weights_shape[i]
}
w_data := make([]byte, w_nbytes)
cbytes := C.CBytes(w_data)
defer C.free(cbytes)
weights := C.mlx_array_new_data(
cbytes,
&weights_shape[0],
C.int(len(weights_shape)),
C.MLX_UINT32,
)
// For scales and bias
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
for i := range shape {
sb_nbytes *= shape[i]
}
s_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(s_data)
defer C.free(cbytes)
scales := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
b_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(b_data)
defer C.free(cbytes)
biases := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
var bits C.int
switch dtype {
case 2:
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 3:
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 8:
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 8
}
groupSize := C.mlx_optional_int{value: 32, has_value: true}
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
C.mlx_dequantize(
&r,
weights,
scales,
biases,
groupSize,
bitsOpt,
nil, // TODO mode
dtypeOpt,
stream,
)
C.mlx_array_free(weights)
C.mlx_array_free(scales)
C.mlx_array_free(biases)
return r, nil
}
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
size := 1
for _, d := range shape {
size *= int(d)
}
fdata := make([]float16.Float16, size)
switch dtype {
case 14:
C.dequant_row_q6_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
case 12:
C.dequant_row_q4_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
default:
return r, fmt.Errorf("unsupported K quant")
}
r = C.mlx_array_new_data(
unsafe.Pointer(&fdata[0]),
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
return r, nil
}

View File

@@ -1,643 +0,0 @@
package ml
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"hash/maphash"
"io"
"log/slog"
"math"
"net/http"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/logutil"
)
// GPULayers is a set of layers to be allocated on a single GPU
type GPULayers struct {
DeviceID
// Layers is a set of layer indicies to load
Layers []int
}
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
func (g GPULayers) FirstLayer() int {
if len(g.Layers) == 0 {
return math.MaxInt
}
first := g.Layers[0]
for i := 1; i < len(g.Layers); i++ {
if g.Layers[i] < first {
first = g.Layers[i]
}
}
return first
}
func (g GPULayers) String() string {
if len(g.Layers) == 0 {
return ""
}
slices.Sort(g.Layers)
contiguous := true
base := g.Layers[0]
for i := range g.Layers {
if g.Layers[i] != base+i {
contiguous = false
break
}
}
if contiguous {
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
} else {
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
}
}
// GPULayersList is a set of layer allocations across multiple GPUs
type GPULayersList []GPULayers
func (l GPULayersList) Len() int { return len(l) }
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
// Sort by the ordering of the layers offloaded
func (l GPULayersList) Less(i, j int) bool {
li := l[i].FirstLayer()
lj := l[j].FirstLayer()
return li < lj
}
func (l GPULayersList) String() string {
if l.Sum() > 0 {
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
} else {
return fmt.Sprintf("%v", []GPULayers(l))
}
}
// Sum is the total number of layers assigned across all GPUs
func (l GPULayersList) Sum() int {
var sum int
for _, g := range l {
sum += len(g.Layers)
}
return sum
}
var h maphash.Hash
// Hash is an identifier of this layer assignment
func (l GPULayersList) Hash() uint64 {
h.Reset()
for _, g := range l {
if len(g.Layers) > 0 {
h.WriteString(g.ID + g.Library)
for _, l := range g.Layers {
binary.Write(&h, binary.NativeEndian, int64(l))
}
}
}
return h.Sum64()
}
// ErrNoMem is returned when panicing due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
BackendMemory
}
func (e ErrNoMem) Error() string {
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}
// Minimal unique device identification
type DeviceID struct {
// ID is an identifier for the device for matching with system
// management libraries. The ID is only unique for other devices
// using the same Library.
// This ID represents a "post filtered" view of the enumerated devices
// if the ID is numeric
ID string `json:"id"`
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
Library string `json:"backend,omitempty"`
}
// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string
// Weights is the per-layer memory needed for the model weights.
Weights []uint64
// Cache is the per-layer memory needed for the KV cache.
Cache []uint64
// Graph is the size of the compute graph. It is not per-layer.
Graph uint64
}
func sumMemory(mem []uint64) uint64 {
var sum uint64
for _, m := range mem {
sum += m
}
return sum
}
// Size returns the total size of the memory required by this device
func (m DeviceMemory) Size() uint64 {
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
}
func memoryPresent(mem []uint64) bool {
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
}
func (m DeviceMemory) LogValue() slog.Value {
var attrs []slog.Attr
if memoryPresent(m.Weights) {
attrs = append(attrs, slog.Any("Weights", m.Weights))
}
if memoryPresent(m.Cache) {
attrs = append(attrs, slog.Any("Cache", m.Cache))
}
if m.Graph != 0 {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
}
return slog.GroupValue(attrs...)
}
// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputWeights are always located on the CPU and cannot be moved
InputWeights uint64
// CPU model components are located in system memory. This does not
// include unified memory allocated through the GPU.
CPU DeviceMemory
// GPU model components are located on one or more GPUs.
GPUs []DeviceMemory
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
for _, g := range m.GPUs {
attrs = append(attrs, slog.Any(g.Name, g))
}
return slog.GroupValue(attrs...)
}
// Log prints a high level summary of the memory
func (m BackendMemory) Log(level slog.Level) {
var total uint64
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := sumMemory(m.CPU.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := gpu.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.CPU.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
if total > 0 {
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
}
}
type DeviceInfo struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string `json:"name"`
// Description is the longer user-friendly identification of the device
Description string `json:"description"`
// FilterID is populated with the unfiltered device ID if a numeric ID is used
// so the device can be included.
FilterID string `json:"filter_id,omitempty"`
// Integrated is set true for integrated GPUs, false for Discrete GPUs
Integrated bool `json:"integration,omitempty"`
// PCIID is the bus, device and domain ID of the device for deduplication
// when discovered by multiple backends
PCIID string `json:"pci_id,omitempty"`
// TotalMemory is the total amount of memory the device can use for loading models
TotalMemory uint64 `json:"total_memory"`
// FreeMemory is the amount of memory currently available on the device for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// ComputeMajor is the major version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMajor int
// ComputeMinor is the minor version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMinor int
// Driver Information
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
// Where backends were loaded from
LibraryPath []string
}
type SystemInfo struct {
// ThreadCount is the optimal number of threads to use for inference
ThreadCount int `json:"threads,omitempty"`
// TotalMemory is the total amount of system memory
TotalMemory uint64 `json:"total_memory,omitempty"`
// FreeMemory is the amount of memory currently available on the system for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// FreeSwap is the amount of system swap space reported as available
FreeSwap uint64 `json:"free_swap,omitempty"`
}
func (d DeviceInfo) Compute() string {
// AMD gfx is encoded into the major minor in hex form
if strings.EqualFold(d.Library, "ROCm") {
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
}
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
}
func (d DeviceInfo) Driver() string {
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
}
// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
if d.Library == "Metal" {
return 512 * format.MebiByte
}
return 457 * format.MebiByte
}
// Sort by Free Space.
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
type ByFreeMemory []DeviceInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
if a[i].Integrated && !a[j].Integrated {
return true
} else if !a[i].Integrated && a[j].Integrated {
return false
}
return a[i].FreeMemory < a[j].FreeMemory
}
// ByPerformance groups devices by similar speed
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
scores := []bool{}
for _, info := range l {
found := false
requested := info.Integrated
for i, score := range scores {
if score == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
scores = append(scores, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath}
for _, gpu := range l {
for _, dir := range gpu.LibraryPath {
needed := true
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
gpuLibs = append(gpuLibs, dir)
}
}
}
return gpuLibs
}
type DeviceComparison int
const (
UniqueDevice DeviceComparison = iota
SameBackendDevice // The device is the same, and the library/backend is the same
DuplicateDevice // The same physical device but different library/backend (overlapping device)
)
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID {
return UniqueDevice
}
// If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID {
return UniqueDevice
}
if a.Library == b.Library {
return SameBackendDevice
}
return DuplicateDevice
}
// For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib {
return false
}
aLibSplit := strings.SplitN(aLib, "_", 2)
bLibSplit := strings.SplitN(bLib, "_", 2)
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
return false
}
if aLibSplit[0] != bLibSplit[0] {
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
return false
}
if aLibSplit[1] == bLibSplit[1] {
return false
}
cmp := []string{aLibSplit[1], bLibSplit[1]}
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
return cmp[0] == bLibSplit[1]
}
// For each GPU, check if it does NOT support flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false
}
}
return true
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
if len(l) == 0 {
return nil
}
env := map[string]string{}
for _, d := range l {
d.updateVisibleDevicesEnv(env, mustFilter)
}
return env
}
// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
}
// Set the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}
// PreferredLibrary returns true if this library is preferred over the other input
// library
// Used to filter out Vulkan in favor of CUDA or ROCm
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
// TODO in the future if we find Vulkan is better than ROCm on some devices
// that implementation can live here.
if d.Library == "CUDA" || d.Library == "ROCm" {
return true
}
return false
}
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
var envVar string
switch d.Library {
case "ROCm":
// ROCm must be filtered as it can crash the runner on unsupported devices
envVar = "ROCR_VISIBLE_DEVICES"
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
}
case "CUDA":
if !mustFilter {
// By default we try to avoid filtering CUDA devices because ROCm also
// looks at the CUDA env var, and gets confused in mixed vendor environments.
return
}
envVar = "CUDA_VISIBLE_DEVICES"
default:
// Vulkan is not filtered via env var, but via scheduling decisions
return
}
v, existing := env[envVar]
if existing {
v = v + ","
}
if d.FilterID != "" {
v = v + d.FilterID
} else {
v = v + d.ID
}
env[envVar] = v
}
type BaseRunner interface {
// GetPort returns the localhost port number the runner is running on
GetPort() int
// HasExited indicates if the runner is no longer running. This can be used during
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
HasExited() bool
}
type RunnerDiscovery interface {
BaseRunner
// GetDeviceInfos will perform a query of the underlying device libraries
// for device identification and free VRAM information
// During bootstrap scenarios, this routine may take seconds to complete
GetDeviceInfos(ctx context.Context) []DeviceInfo
}
type FilteredRunnerDiscovery interface {
RunnerDiscovery
// GetActiveDeviceIDs returns the filtered set of devices actively in
// use by this runner for running models. If the runner is a bootstrap runner, no devices
// will be active yet so no device IDs are returned.
// This routine will not query the underlying device and will return immediately
GetActiveDeviceIDs() []DeviceID
}
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
var moreDevices []DeviceInfo
port := runner.GetPort()
tick := time.Tick(10 * time.Millisecond)
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout")
case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
// slog.Warn("failed to send request", "error", err)
if runner.HasExited() {
return nil, fmt.Errorf("runner crashed")
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
slog.Warn("failed to read response", "error", err)
continue
}
if resp.StatusCode != 200 {
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
return nil, fmt.Errorf("runner error: %s", string(body))
}
if err := json.Unmarshal(body, &moreDevices); err != nil {
slog.Warn("unmarshal encode response", "error", err)
continue
}
return moreDevices, nil
}
}
}

View File

@@ -1,103 +0,0 @@
package nn
import (
"fmt"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
// panic("after cache get") //
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
// var mask ml.Tensor
if cache != nil {
key, value, _ = cache.Get(ctx)
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
// panic("after cache get") //
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if cache != nil {
// TODO what to do with vmla?
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
} else {
panic("else case not supported")
// TODO transpose shapes are wrong
// key = key.Transpose(ctx, 0, 2, 1, 3)
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
// kq := query.Matmul(ctx, key)
// kq = kq.Scale(ctx, scale)
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
// kq = kq.Softmax(ctx)
// kqv := kq.Matmul(ctx, value)
// if vmla != nil {
// kqv = kqv.Matmul(ctx, vmla)
// }
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
}
}

View File

@@ -1,30 +0,0 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Conv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
if m.Bias != nil {
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
}
return t
}
type Conv3D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}

View File

@@ -1,11 +0,0 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Embedding struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
return m.Weight.TakeAxes(ctx, hiddenState, 0)
}

View File

@@ -1,32 +0,0 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Linear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
panic("not yet ported")
// t = m.Weight.MulmatID(ctx, t, indices)
// if m.Bias != nil {
// t = t.AddID(ctx, m.Bias, indices)
// }
// return t
}

View File

@@ -1,29 +0,0 @@
package nn
import (
"github.com/ollama/ollama/x/ml"
)
type LayerNorm struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}
type RMSNorm struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
// slog.Info("RMSNorm", "eps", eps)
// fmt.Fprintln(os.Stderr, t.ToString())
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
// TODO this is probably model specific, not generalized...
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
return t.RMSNorm(ctx, w, eps)
}

View File

@@ -1,41 +0,0 @@
package pooling
import (
"github.com/ollama/ollama/x/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
// case TypeMean:
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
// case TypeCLS:
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
// case TypeLast:
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
default:
panic("unknown pooling type")
}
}

View File

@@ -1,72 +0,0 @@
package rope
import "github.com/ollama/ollama/x/ml"
// Options contains optional parameters for RoPE function
type Options struct {
Type int
Factors ml.Tensor
// YaRN options
YaRN struct {
OriginalContextLength int
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// MRoPE options
MRoPE struct {
Sections []int
}
}
// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
return func(opts *Options) {
opts.Type = 2
}
}
// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
return func(opts *Options) {
if factors != nil {
opts.Factors = factors
}
}
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.YaRN.OriginalContextLength = n
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.AttentionFactor = attentionFactor
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
opts.MRoPE.Sections = sections
}
}
func WithInterleaveMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1<<3 | 1<<5
opts.MRoPE.Sections = sections
}
}

View File

@@ -1,56 +0,0 @@
package ml
import (
"os"
"path/filepath"
"runtime"
)
// LibPath is a path to lookup dynamic libraries
// in development it's usually 'build/lib/ollama'
// in distribution builds it's 'lib/ollama' on Windows
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {
return ""
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string
switch runtime.GOOS {
case "windows":
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
case "linux":
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
case "darwin":
libPath = filepath.Dir(exe)
}
cwd, err := os.Getwd()
if err != nil {
return ""
}
paths := []string{
libPath,
// build paths for development
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
filepath.Join(cwd, "build", "lib", "ollama"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
return filepath.Dir(exe)
}()

View File

@@ -1,322 +0,0 @@
package model
import (
"bufio"
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
t.Fatal(err)
}
types := make([]int32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token
types[id] = 1
}
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
tokens = append(tokens, token) //nolint:makezero
types = append(types, 3) //nolint:makezero
vocab[token] = int32(len(vocab))
}
}
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
merges := make([]string, 0, 50000)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "#") {
merges = append(merges, scanner.Text())
}
}
return NewBytePairEncoding(
&Vocabulary{
Values: tokens,
Types: types,
Merges: merges,
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
}
func TestLlama(t *testing.T) {
tokenizer := llama(t)
t.Run("simple", func(t *testing.T) {
t.Parallel()
ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
s, err := tokenizer.Decode([]int32{15339, 1917})
if err != nil {
t.Fatal(err)
}
if s != "hello world" {
t.Errorf("got %q, want hello world", s)
}
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
t.Run("simple repeated", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
strings.Repeat("0", 1): {15},
strings.Repeat("0", 2): {410},
strings.Repeat("0", 3): {931},
strings.Repeat("0", 4): {931, 15},
strings.Repeat("0", 5): {931, 410},
strings.Repeat("0", 6): {931, 931},
strings.Repeat("0", 7): {931, 931, 15},
strings.Repeat("0", 8): {931, 931, 410},
strings.Repeat("0", 9): {931, 931, 931},
strings.Repeat("0", 10): {931, 931, 931, 15},
strings.Repeat("0", 11): {931, 931, 931, 410},
strings.Repeat("0", 12): {931, 931, 931, 931},
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
}
}
})
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
})
t.Run("special", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("split", func(t *testing.T) {
t.Parallel()
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
"Hello World": {"Hello", " ", " World"},
"Hello\nWorld": {"Hello", "\n", "World"},
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
t.Parallel()
for b := 0x00; b <= 0xFF; b++ {
input := string(rune(b))
ids, err := tokenizer.Encode(input, false)
if err != nil {
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
continue
}
if b == 0x00 {
if len(decoded) != 0 {
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
}
continue
}
if decoded != input {
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
}
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
if err != nil {
b.Fatal(err)
}
for i := range 8 {
n := min(int(math.Pow10(i)), len(bts))
bts := bts[:n]
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Decode(ids)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
slices.Collect(tokenizer.split(string(bts)))
}
})
}
}
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}

View File

@@ -1,76 +0,0 @@
package input
import "github.com/ollama/ollama/x/ml"
// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
// Tensor is the embedding data. Implementations may chose what to
// store here or it may be nil if not needed. However, any ml.Tensor
// objects must be stored here and not in Data.
Tensor ml.Tensor
// Data is implementation-specific opaque data, such as metadata on how
// to layout Tensor. It may be nil if not needed. It may also store larger
// objects such as complete images if they are to be processed later.
Data any
}
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is represents a non-text element such as an
// image (or part of one if the image can be processed in pieces).
// It may be used either together with Token or on its own.
Multimodal []Multimodal
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal []Multimodal
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs ml.Tensor
// TODO maybe not the optimal way to handle this
// Offset of final tensor in the final batch
Offset int
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
}

View File

@@ -1,333 +0,0 @@
package model
import (
"errors"
"fmt"
_ "image/jpeg"
_ "image/png"
"log/slog"
"os"
"reflect"
"strconv"
"strings"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
_ "github.com/ollama/ollama/x/ml/backend"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is one or more tensors, each with optional model-specific
// opaque metadata. Typically, the tensors might be views into an embedding
// with each view representing a chunk of data that can be processed independently
// in different batches.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]*input.Input) ([]*input.Input, error)
}
// Base implements the common fields and methods for all models
type Base struct {
b ml.Backend
config
}
type config struct {
Cache kvcache.Cache
}
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
}
func (m *Base) Config() config {
return m.config
}
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture
func Register(name string, f func(fs.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
models[name] = f
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}
m, err := modelForArch(b.Config())
if err != nil {
return nil, err
}
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
m, err := modelForArch(meta.KV())
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}
return tp, nil
}
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
if t.Kind() == reflect.Struct {
allNil := true
for i := range t.NumField() {
tt := t.Field(i).Type
vv := v.Field(i)
if !vv.CanSet() {
continue
}
// make a copy
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, parseTag(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 {
var names []string
if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix)
}
}
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
if len(names) == 0 {
// current tag has no name, use child names only
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
}
} else {
// merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
}
}
}
}
return fullNames
}
names := fn(tagsCopy, "", "")
for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
}
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
}
}
}
if !canNil(tt) || !vv.IsNil() {
allNil = false
}
}
if allNil {
return reflect.Zero(t)
}
}
return v
}
func setPointer(base Base, v reflect.Value, tags []Tag) {
vv := v
if v.Kind() == reflect.Interface {
if v.IsNil() {
return
}
vv = vv.Elem()
}
vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
if f := populateFields(base, vv, tags...); f.CanAddr() {
v.Set(f.Addr())
}
}
type Tag struct {
name,
// prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
}
func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",")
if len(parts) > 0 {
tag.name = parts[0]
for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
// elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
}
}
}
return
}
func canNil(t reflect.Type) bool {
return t.Kind() == reflect.Chan ||
t.Kind() == reflect.Func ||
t.Kind() == reflect.Interface ||
t.Kind() == reflect.Map ||
t.Kind() == reflect.Pointer ||
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
}
ctx.Forward(t)
return t, nil
}

View File

@@ -1,58 +0,0 @@
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type embedModel struct {
model.Base
model.SentencePiece
*TextModel
poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
return m, nil
}

View File

@@ -1,157 +0,0 @@
//go:build mlx
package gemma3
import (
"bytes"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type Model struct {
model.Base
model.SentencePiece
*VisionModel `gguf:"vision_tower.vision_model"`
*TextModel `gguf:"language_model.model"`
*MultiModalProjector `gguf:"multi_modal_projector"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
// slog.Info("XXX Config", "c", c)
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
// slidingWindowLen := int32(c.Uint("attention.sliding_window"))
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
m.Cache = kvcache.NewCausalCache()
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues := ctx.Input().FromFloats(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
&input.Input{Token: 255999}, // "<start_of_image>""
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
&input.Input{Token: 256000}, // <end_of_image>
&input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}

View File

@@ -1,211 +0,0 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model/input"
)
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
Layers []TextLayer `gguf:"layers"`
OutputNorm *nn.RMSNorm `gguf:"norm"`
Output *nn.Linear `gguf:"embed_tokens"`
*TextConfig
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
// const (
// cacheTypeSWA = iota
// cacheTypeCausal
// )
func newTextModel(c fs.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
ropeScale: 1, // 1 - default is 1, implied in python code
// vocabSize: vocabSize, // 262144
// attnValLen: int(c.Uint("attention.value_length", 256)), //256
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"q_proj"`
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
Key *nn.Linear `gguf:"k_proj"`
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
Value *nn.Linear `gguf:"v_proj"`
Output *nn.Linear `gguf:"o_proj"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
B := hiddenState.Dim(0)
L := hiddenState.Dim(1)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
k := sa.Key.Forward(ctx, hiddenState)
v := sa.Value.Forward(ctx, hiddenState)
q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
traditional := false
q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
// TODO - this is wrong somehow so commenting out
// if opts.largeModelScaling {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
// } else {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
// }
scaleFactor := math.Pow(256, -0.5)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// ropeBase := m.TextConfig.ropeLocalBase
// if (layer+1)%gemmaGlobalCacheCount == 0 {
// ropeBase = m.TextConfig.ropeGlobalBase
// }
// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
panic("not yet implemented")
// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"up_proj"`
Down *nn.Linear `gguf:"down_proj"`
Gate *nn.Linear `gguf:"gate_proj"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *TextSelfAttention `gguf:"self_attn"`
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
MLP *TextMLP `gguf:"mlp"`
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
residual = residual.TakeAxes(ctx, outputs, 1)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
// var except []int
// for _, image := range batch.Multimodal {
// visionOutputs := image.Multimodal[0].Tensor
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
// []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
// []int{image.Index * hiddenState.Stride(1)}, 0)))
// for i := range visionOutputs.Dim(1) {
// except = append(except, image.Index+i)
// }
// }
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
if cache != nil {
// cacheType := cacheTypeSWA
// if (i+1)%gemmaGlobalCacheCount == 0 {
// cacheType = cacheTypeCausal
// }
cache.SetLayer(i)
// TODO this needs to come back
// wc := cache.(*kvcache.WrapperCache)
// wc.SetLayerType(cacheType)
// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
// causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
// }
}
var offset int
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
offset = batch.Offset
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}

View File

@@ -1,121 +0,0 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.out_proj"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"encoder.layers"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@@ -1,60 +0,0 @@
//go:build mlx
package gemma3
import (
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}

View File

@@ -1,3 +0,0 @@
package models
// _ "github.com/ollama/ollama/x/model/models/gemma3"

View File

@@ -1,249 +0,0 @@
package model
import (
"container/heap"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/ollama/ollama/logutil"
)
const spmWhitespaceSep = "▁"
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
}
if addSpecial {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}

View File

@@ -1,172 +0,0 @@
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, int32(t))
default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
"你好",
"Hello 你好 world!",
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
"Numbers and symbols: 123456789 +- */",
"Special tokens: <bos> text <eos>",
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}
func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
"<0xEA>",
"<0x41>",
"<0xC3>",
"<0xA3>",
},
Types: []int32{
TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
},
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePiece(vocab)
tests := []struct {
name string
ids []int32
expected string
}{
{
name: "single byte token",
ids: []int32{1},
expected: "\xea",
},
{
name: "ASCII byte token",
ids: []int32{2},
expected: "A",
},
{
name: "multiple byte tokens forming UTF-8 character",
ids: []int32{3, 4},
expected: "ã",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := spm.Decode(tt.ids)
if err != nil {
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
}
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -1,17 +0,0 @@
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}

View File

@@ -1,112 +0,0 @@
package model
import (
"log/slog"
"slices"
"sync"
)
type Special int32
const (
SpecialBOS Special = iota
SpecialEOS
)
type Vocabulary struct {
Values []string
Types []int32
Scores []float32
Merges []string
BOS, EOS []int32
AddBOS, AddEOS bool
specialOnce sync.Once
special []string
valuesOnce sync.Once
values map[string]int32
mergeOnce sync.Once
merge map[string]int32
}
func (v *Vocabulary) Is(id int32, special Special) bool {
switch special {
case SpecialBOS:
return slices.Contains(v.BOS, id)
case SpecialEOS:
return slices.Contains(v.EOS, id)
default:
return false
}
}
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
if v.AddBOS && len(v.BOS) > 0 {
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
ids = append([]int32{v.BOS[0]}, ids...)
}
if v.AddEOS && len(v.EOS) > 0 {
if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
ids = append(ids, v.EOS[0])
}
return ids
}
func (v *Vocabulary) Encode(s string) int32 {
v.valuesOnce.Do(func() {
v.values = make(map[string]int32, len(v.Values))
for i, value := range v.Values {
v.values[value] = int32(i)
}
})
if id, ok := v.values[s]; ok {
return id
}
return -1
}
func (v *Vocabulary) Decode(id int32) string {
return v.Values[id]
}
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
v.special = append(v.special, v.Values[i])
}
}
})
return v.special
}
func (v *Vocabulary) Merge(left, right string) int {
v.mergeOnce.Do(func() {
v.merge = make(map[string]int32, len(v.Merges))
for i, merge := range v.Merges {
v.merge[merge] = int32(i)
}
})
if id, ok := v.merge[left+" "+right]; ok {
return int(id)
}
return -1
}

View File

@@ -1,107 +0,0 @@
package model
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func TestSpecialVocabulary(t *testing.T) {
vocab := &Vocabulary{
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
}
specialVocab := vocab.SpecialVocabulary()
if len(specialVocab) != 4 {
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
}
}
func TestAddSpecialVocabulary(t *testing.T) {
cases := []struct {
name string
vocab *Vocabulary
input []int32
want []int32
}{
{
name: "add bos",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{2, 3, 4},
want: []int32{0, 2, 3, 4},
},
{
// TODO(mxyng): this is to match previous behaviour
name: "add bos when already present",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{0, 2, 3, 4},
want: []int32{0, 0, 2, 3, 4},
},
{
name: "add eos",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: false,
AddEOS: true,
},
input: []int32{2, 3, 4},
want: []int32{2, 3, 4, 1},
},
{
// TODO(mxyng): this is to match previous behaviour
name: "add eos when already present",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: false,
AddEOS: true,
},
input: []int32{2, 3, 4, 1},
want: []int32{2, 3, 4, 1, 1},
},
{
name: "add both",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: true,
},
input: []int32{2, 3, 4},
want: []int32{0, 2, 3, 4, 1},
},
{
name: "add bos to empty inputs",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{},
want: []int32{0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := tt.vocab.addSpecials(tt.input)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("no match (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -1,171 +0,0 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
lowercase bool
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on but keep punctuation
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
if wpm.lowercase {
subword = strings.ToLower(subword)
}
piece = wpm.vocab.Encode(subword)
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// Unknown token
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{
vocab: vocab,
lowercase: lowercase,
}
}