mirror of
https://github.com/ollama/ollama.git
synced 2026-02-05 21:23:43 -05:00
Compare commits
2 Commits
v0.15.5-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1373193dc | ||
|
|
8a4b77f9da |
@@ -482,6 +482,8 @@ Examples:
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
return runIntegration(name, saved.Models[0], passArgs)
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
|
||||
@@ -502,6 +502,28 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save a config for opencode so it looks like a previous launch
|
||||
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify loadIntegration returns the saved models
|
||||
saved, err := loadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(saved.Models) == 0 {
|
||||
t.Fatal("expected saved models")
|
||||
}
|
||||
if saved.Models[0] != "llama3.2" {
|
||||
t.Errorf("expected llama3.2, got %s", saved.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
@@ -10,12 +11,52 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
@@ -113,6 +154,8 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -122,12 +165,29 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -495,6 +496,165 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry, ok := models[model].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("model %s not found in config", model)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
// Set up a model with a user-configured limit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"llama3.2": {
|
||||
"name": "llama3.2",
|
||||
"_launch": true,
|
||||
"limit": {"context": 8192, "output": 4096}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
// Re-edit should preserve the user's limit (not delete it)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("user-configured limit was removed")
|
||||
}
|
||||
if limit["context"] != float64(8192) {
|
||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(4096) {
|
||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||
// the structure matches what opencode expects and re-edit preserves them.
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud",
|
||||
"_launch": true,
|
||||
"limit": {"context": %d, "output": %d}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, expected.Context, expected.Output)), 0o644)
|
||||
|
||||
// Re-edit should preserve the cloud model limit
|
||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was removed on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -116,7 +117,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
var tok tokenizer.Tokenizer
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
tok, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
textProcessor != nil,
|
||||
tok != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
if tok != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &model.Vocabulary{
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -134,8 +135,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -43,8 +44,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Tokenizer: t,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,12 +29,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -32,8 +33,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Tokenizer: tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +34,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,28 +45,24 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWordPiece(t *testing.T) {
|
||||
wpm := NewWordPiece(
|
||||
&Vocabulary{
|
||||
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
BOS: []int32{1},
|
||||
EOS: []int32{2},
|
||||
},
|
||||
true, // lowercase
|
||||
)
|
||||
|
||||
ids, err := wpm.Encode("Hello world!", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
words, err := wpm.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
|
||||
decoder := func(tokenID int) string {
|
||||
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
|
||||
text, _ := tok.Decode([]int32{int32(tokenID)})
|
||||
return text
|
||||
}
|
||||
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
|
||||
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
|
||||
if err != nil {
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
|
||||
if seq.logprobs {
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
|
||||
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
|
||||
}
|
||||
|
||||
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -168,15 +168,15 @@ type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
|
||||
pieces := make([]string, len(tok.Vocabulary().Values))
|
||||
for i := range tok.Vocabulary().Values {
|
||||
pieces[i], _ = tok.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
return tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
Merges: merges,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||
if len(pretokenizer) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
for _, p := range pretokenizer {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
|
||||
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
@@ -17,7 +17,7 @@ type SentencePiece struct {
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
var _ Tokenizer = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// For tokenizer that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type TextProcessor interface {
|
||||
type Tokenizer interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements TextProcessor.
|
||||
// Decode implements Tokenizer.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements TextProcessor.
|
||||
// Encode implements Tokenizer.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements TextProcessor.
|
||||
// Is implements Tokenizer.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements TextProcessor.
|
||||
// Vocabulary implements Tokenizer.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
var _ Tokenizer = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -1,77 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||
ErrNotSupported = errors.New("model does not support operation")
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
// ** used by model implementations **
|
||||
|
||||
// SetLayer sets the active layer of the cache
|
||||
SetLayer(layer int)
|
||||
|
||||
// Get returns the history of key and value tensors plus a mask
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||
|
||||
// Put stores a batch of key and value in the cache
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output of the cache to work better with specific kernels. If not called,
|
||||
// the backend settings will be used. This works well when calling Attention.
|
||||
//
|
||||
// The config can be overridden by models, especially if they require vanilla
|
||||
// output when implementing their own version of attention. To do this, pass
|
||||
// an empty ml.CacheConfig.
|
||||
//
|
||||
// Most models will not need to use this.
|
||||
SetConfig(ml.CacheConfig)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters.
|
||||
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||
// dtype: The data type for storing cache entries
|
||||
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||
// capacity: The number of cache entries to store, per sequence
|
||||
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs. reserve is to preallocate memory
|
||||
// without actually storing data in the cache.
|
||||
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
// If an error occurs, the entire context for the sequence should be
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
// Causal cache stores K and V tensors according to their position in the
|
||||
// sequence. Returns the history and a mask for attending to past tokens
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocPut ml.Tensor
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocGet ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
capacity int
|
||||
|
||||
offset int
|
||||
|
||||
backend ml.Backend
|
||||
ctxs map[int]ml.Context
|
||||
keys, values map[int]ml.Tensor
|
||||
|
||||
// TODO is this needed per layer, or will it always be consistent?
|
||||
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
}
|
||||
|
||||
func NewCausalCache() *Causal {
|
||||
return &Causal{
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
kHeadDims: make(map[int]int),
|
||||
vHeadDims: make(map[int]int),
|
||||
numKVHeads: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.DType = dtype
|
||||
c.capacity = capacity
|
||||
c.backend = backend
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
for _, ctx := range c.ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
locsPut := make([]int32, len(batch.Positions))
|
||||
for i := c.offset; i < len(batch.Positions); i++ {
|
||||
locsPut[i-c.offset] = int32(i)
|
||||
}
|
||||
c.offset += len(batch.Positions)
|
||||
locsGet := make([]int32, c.offset)
|
||||
for i := range c.offset {
|
||||
locsGet[i] = int32(i)
|
||||
}
|
||||
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(3)
|
||||
vHeadDim := value.Dim(3)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
kCellSize := kHeadDim * numKVHeads
|
||||
vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||
|
||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||
c.kHeadDims[c.curLayer] = kHeadDim
|
||||
c.vHeadDims[c.curLayer] = vHeadDim
|
||||
c.numKVHeads[c.curLayer] = numKVHeads
|
||||
}
|
||||
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||
|
||||
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
|
||||
// slog.Info("XXX Causal.Put ", "key", key)
|
||||
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||
|
||||
}
|
||||
|
||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := c.kHeadDims[c.curLayer]
|
||||
vHeadDim := c.vHeadDims[c.curLayer]
|
||||
numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// rowSize := numKVHeads * c.curBatchSize
|
||||
// cachedSize := c.curMask.Dim(1)
|
||||
cachedSize := c.curLocGet.Dim(0)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||
|
||||
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Encoder cache stores K and V tensors that are position independent
|
||||
// //
|
||||
// // The tensors can be of any shape and will be returned as they were stored
|
||||
// // The mask is currently always nil
|
||||
// //
|
||||
// // Not currently safe for multiple sequences
|
||||
// type EncoderCache struct {
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // if something is stored during this pass, this
|
||||
// // will be the position (but there is no guarantee
|
||||
// // anything will be stored)
|
||||
// curPos int32
|
||||
|
||||
// // curReserve indicates that this forward pass is only for
|
||||
// // memory reservation and we should not update our metadata
|
||||
// // based on it.
|
||||
// curReserve bool
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // was something stored in the cache?
|
||||
// encoderCached bool
|
||||
|
||||
// // position of the cached data
|
||||
// encoderPos int32
|
||||
|
||||
// // ** cache data storage **
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
// }
|
||||
|
||||
// func NewEncoderCache() *EncoderCache {
|
||||
// return &EncoderCache{
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if maxSequences > 1 {
|
||||
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
// }
|
||||
|
||||
// c.backend = backend
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Close() {
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// // We work with the most recent image
|
||||
// if len(batch.Multimodal) > 0 {
|
||||
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
// }
|
||||
|
||||
// c.curReserve = reserve
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) EncoderCached() bool {
|
||||
// return c.encoderCached
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// if !c.curReserve {
|
||||
// c.encoderPos = c.curPos
|
||||
// c.encoderCached = true
|
||||
// }
|
||||
|
||||
// if c.config.PermutedV {
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3)
|
||||
// }
|
||||
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
||||
// }
|
||||
|
||||
// ctx.Forward(
|
||||
// key.Copy(ctx, c.keys[c.curLayer]),
|
||||
// value.Copy(ctx, c.values[c.curLayer]),
|
||||
// )
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// panic("encoder cache does not support multiple sequences")
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
// c.encoderCached = false
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
@@ -1,110 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "math"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Wrapper cache is a container for multiple types of caches,
|
||||
// // such as for the encoding and decoding portions of a model.
|
||||
// type WrapperCache struct {
|
||||
// // caches we are wrapping
|
||||
// caches []Cache
|
||||
|
||||
// // cache to be used for this layer
|
||||
// curType int
|
||||
// }
|
||||
|
||||
// func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||
// return &WrapperCache{
|
||||
// caches: caches,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.SetConfig(config)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Close() {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// for i, cache := range c.caches {
|
||||
// err := cache.StartForward(ctx, batch, reserve)
|
||||
// if err != nil {
|
||||
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
// for j := i - 1; j >= 0; j-- {
|
||||
// for k := range batch.Positions {
|
||||
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||
// }
|
||||
// }
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.curType = 0
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetLayer(layer int) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.SetLayer(layer)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) SetLayerType(layerType int) {
|
||||
// c.curType = layerType
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) UnderlyingCache() Cache {
|
||||
// return c.caches[c.curType]
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.caches[c.curType].Get(ctx)
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// c.caches[c.curType].Put(ctx, key, value)
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// for _, cache := range c.caches {
|
||||
// cache.CopyPrefix(srcSeq, dstSeq, len)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||
// for _, cache := range c.caches {
|
||||
// if !cache.CanResume(seq, pos) {
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
// for _, cache := range c.caches {
|
||||
// err := cache.Remove(seq, beginIndex, endIndex)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
433
x/ml/backend.go
433
x/ml/backend.go
@@ -1,433 +0,0 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type Backend interface {
|
||||
// Close frees all memory associated with this backend
|
||||
// Close()
|
||||
|
||||
// Load(ctx context.Context, progress func(float32)) error
|
||||
|
||||
// BackendMemory returns the memory allocations that were made for this model
|
||||
// BackendMemory() BackendMemory
|
||||
|
||||
Config() fs.Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
// NewContextSize(size int) Context
|
||||
|
||||
// Enumerate the devices available for inference via this backend
|
||||
// BackendDevices() []DeviceInfo
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
// from the cache to meet specific requirements. It is frequently implemented in
|
||||
// conjunction with ScaledDotProductAttention.
|
||||
type BackendCacheConfig interface {
|
||||
CacheConfig() CacheConfig
|
||||
}
|
||||
|
||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output the cache to work better with specific kernels.
|
||||
type CacheConfig struct {
|
||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||
// cache itself will also be increased to a multiple of this size if needed.
|
||||
CachePadding int
|
||||
|
||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||
// and return the permuted version via Get. This uses the cache copy operation
|
||||
// to avoid a Contiguous call on the permuted tensor.
|
||||
PermutedV bool
|
||||
|
||||
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||
// default to DTypeF32.
|
||||
MaskDType DType
|
||||
|
||||
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||
MaskBatchPadding int
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// AllocMemory causes the backend to allocate memory for the model. If
|
||||
// false, this is only being used for discovering the required amount of
|
||||
// memory and cannot load the model for running.
|
||||
AllocMemory bool
|
||||
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
NumThreads int
|
||||
|
||||
// GPULayers is the set of layers to offload to GPUs
|
||||
GPULayers GPULayersList
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
|
||||
if _, ok := backends[name]; ok {
|
||||
panic("backend: backend already registered")
|
||||
}
|
||||
|
||||
backends[name] = f
|
||||
}
|
||||
|
||||
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
|
||||
be := os.Getenv("OLLAMA_BACKEND")
|
||||
if be == "" {
|
||||
be = "mlx"
|
||||
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
|
||||
}
|
||||
slog.Info("Loading new engine", "backend", be)
|
||||
if backend, ok := backends[be]; ok {
|
||||
return backend(modelPath, params)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported backend")
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
|
||||
FromFloats(s []float32, shape ...int) Tensor
|
||||
FromInts(s []int32, shape ...int) Tensor
|
||||
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
|
||||
|
||||
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
|
||||
Arange(start, stop, step float32, dtype DType) Tensor
|
||||
|
||||
Forward(...Tensor) Context
|
||||
|
||||
// SetBatchSize provides a hint on the batch size to optimize processing
|
||||
// Uses heuristics if not set
|
||||
// SetBatchSize(int)
|
||||
|
||||
Compute(...Tensor)
|
||||
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
|
||||
|
||||
// Reserve is analogous to Compute but rather than executing a
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
// worst case graph to ensure all resources are available for
|
||||
// for future inference.
|
||||
// Reserve()
|
||||
|
||||
// MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
// Input returns a context appropriate for creating tensors that are
|
||||
// inputs to the model (which includes things like output locations)
|
||||
Input() Context
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
|
||||
// Load a tensor from "filename" safetensors file, and compare with the input tensor
|
||||
// Returns error if the shape is inconsistent, or similarity measures are below 99%
|
||||
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
|
||||
}
|
||||
|
||||
type RoPEOptions struct {
|
||||
Base *float32
|
||||
Freqs Tensor
|
||||
}
|
||||
|
||||
func WithRoPEBase(base float32) func(*RoPEOptions) {
|
||||
return func(opts *RoPEOptions) {
|
||||
opts.Base = &base
|
||||
}
|
||||
}
|
||||
|
||||
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
|
||||
return func(opts *RoPEOptions) {
|
||||
opts.Freqs = freqs
|
||||
}
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
ToString() string
|
||||
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
|
||||
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
|
||||
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
|
||||
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
|
||||
|
||||
Dim(n int) int
|
||||
Stride(n int) int
|
||||
|
||||
Shape() []int
|
||||
DType() DType
|
||||
// Cast(ctx Context, dtype DType) Tensor
|
||||
|
||||
// Bytes() []byte
|
||||
Floats() []float32
|
||||
Ints() []int32
|
||||
|
||||
// FromBytes([]byte)
|
||||
// FromFloats([]float32)
|
||||
// FromInts([]int32)
|
||||
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
// Mul(ctx Context, t2 Tensor) Tensor
|
||||
// Div(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
Max(ctx Context, axes []int, keepDims bool) Tensor
|
||||
Min(ctx Context, axes []int, keepDims bool) Tensor
|
||||
|
||||
Matmul(ctx Context, a2 Tensor) Tensor
|
||||
// Mulmat(ctx Context, t2 Tensor) Tensor
|
||||
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||
// MulmatID(ctx Context, t2, ids Tensor) Tensor
|
||||
// AddID(ctx Context, t2, ids Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
L2Norm(ctx Context, eps float32) Tensor
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
// SumRows(ctx Context) Tensor
|
||||
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
|
||||
|
||||
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
// Sin(ctx Context) Tensor
|
||||
// Cos(ctx Context) Tensor
|
||||
// Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
// QuickGELU(ctx Context, up ...Tensor) Tensor
|
||||
// SILU(ctx Context, up ...Tensor) Tensor
|
||||
// RELU(ctx Context, up ...Tensor) Tensor
|
||||
// Sigmoid(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
|
||||
Transpose(ctx Context, shape ...int) Tensor
|
||||
Contiguous(ctx Context, allowColMajor bool) Tensor
|
||||
|
||||
// Pad(ctx Context, shape ...int) Tensor
|
||||
|
||||
// Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
|
||||
// Repeat repeats the tensor n times along dimension dim
|
||||
// Repeat(ctx Context, dim, n int) Tensor
|
||||
// Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||
// Rows(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
// TODO these probably aren't actually needed - false starts on trying to wire up cache
|
||||
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
|
||||
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
|
||||
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
|
||||
|
||||
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
|
||||
|
||||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
// Duplicate(ctx Context) Tensor
|
||||
|
||||
// Slice(ctx Context, dim, low, high, step int) Tensor
|
||||
// Chunk(ctx Context, dim int, size int) []Tensor
|
||||
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
|
||||
|
||||
// TopK(ctx Context, k int) Tensor
|
||||
// Argsort(ctx Context) Tensor
|
||||
// Mean(ctx Context) Tensor
|
||||
// Variance(ctx Context) Tensor
|
||||
// Stddev(ctx Context) Tensor
|
||||
// Sqr(ctx Context) Tensor
|
||||
// Sqrt(ctx Context) Tensor
|
||||
|
||||
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
//
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
//
|
||||
// kq = kq.Softmax(ctx)
|
||||
//
|
||||
// kqv := value.Mulmat(ctx, kq)
|
||||
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
// type ScaledDotProductAttention interface {
|
||||
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
|
||||
// }
|
||||
|
||||
// type number interface {
|
||||
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
// ~float32 | ~float64 |
|
||||
// ~complex64 | ~complex128
|
||||
// }
|
||||
|
||||
// func mul[T number](s ...T) T {
|
||||
// p := T(1)
|
||||
// for _, v := range s {
|
||||
// p *= v
|
||||
// }
|
||||
|
||||
// return p
|
||||
// }
|
||||
|
||||
// type DumpOptions func(*dumpOptions)
|
||||
|
||||
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
|
||||
// func DumpWithPrecision(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.Precision = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
|
||||
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
|
||||
// // beginning and end of each dimension will be printed.
|
||||
// func DumpWithThreshold(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.Threshold = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
|
||||
// func DumpWithEdgeItems(n int) DumpOptions {
|
||||
// return func(opts *dumpOptions) {
|
||||
// opts.EdgeItems = n
|
||||
// }
|
||||
// }
|
||||
|
||||
// type dumpOptions struct {
|
||||
// Precision, Threshold, EdgeItems int
|
||||
// }
|
||||
|
||||
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
|
||||
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
|
||||
// for _, optsFunc := range optsFuncs {
|
||||
// optsFunc(&opts)
|
||||
// }
|
||||
|
||||
// if mul(t.Shape()...) <= opts.Threshold {
|
||||
// opts.EdgeItems = math.MaxInt
|
||||
// }
|
||||
|
||||
// switch t.DType() {
|
||||
// case DTypeFloat32:
|
||||
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
|
||||
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
// })
|
||||
// case DTypeFloat16: // TODO other types...
|
||||
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
|
||||
// f32 = t.Copy(ctx, f32)
|
||||
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
|
||||
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
// })
|
||||
// case DTypeInt32:
|
||||
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
|
||||
// return strconv.FormatInt(int64(i), 10)
|
||||
// })
|
||||
// default:
|
||||
// return "<unsupported>"
|
||||
// }
|
||||
// }
|
||||
|
||||
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
// if t.Bytes() == nil {
|
||||
// ctx.Compute(t)
|
||||
// }
|
||||
|
||||
// s := make(S, mul(t.Shape()...))
|
||||
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// shape := t.Shape()
|
||||
// slices.Reverse(shape)
|
||||
|
||||
// var sb strings.Builder
|
||||
// var f func([]int, int)
|
||||
// f = func(dims []int, stride int) {
|
||||
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||
// sb.WriteString("[")
|
||||
// defer func() { sb.WriteString("]") }()
|
||||
// for i := 0; i < dims[0]; i++ {
|
||||
// if i >= items && i < dims[0]-items {
|
||||
// sb.WriteString("..., ")
|
||||
// // skip to next printable element
|
||||
// skip := dims[0] - 2*items
|
||||
// if len(dims) > 1 {
|
||||
// stride += mul(append(dims[1:], skip)...)
|
||||
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
||||
// }
|
||||
// i += skip - 1
|
||||
// } else if len(dims) > 1 {
|
||||
// f(dims[1:], stride)
|
||||
// stride += mul(dims[1:]...)
|
||||
// if i < dims[0]-1 {
|
||||
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||
// }
|
||||
// } else {
|
||||
// text := fn(s[stride+i])
|
||||
// if len(text) > 0 && text[0] != '-' {
|
||||
// sb.WriteString(" ")
|
||||
// }
|
||||
|
||||
// sb.WriteString(text)
|
||||
// if i < dims[0]-1 {
|
||||
// sb.WriteString(", ")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// f(shape, 0)
|
||||
|
||||
// return sb.String()
|
||||
// }
|
||||
|
||||
type DType int
|
||||
|
||||
const (
|
||||
DTypeBool DType = iota
|
||||
DTypeUint8
|
||||
DTypeUint16
|
||||
DTypeUint32
|
||||
DTypeUint64
|
||||
DTypeInt8
|
||||
DTypeInt16
|
||||
DTypeInt32
|
||||
DTypeInt64
|
||||
DTypeFloat16
|
||||
DTypeFloat32
|
||||
DTypeFloat64
|
||||
DTypeBfloat16
|
||||
DTypeComplex64
|
||||
)
|
||||
|
||||
type SamplingMode int
|
||||
|
||||
const (
|
||||
SamplingModeNearest SamplingMode = iota
|
||||
SamplingModeBilinear
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
package backend
|
||||
|
||||
// _ "github.com/ollama/ollama/x/ml/backend/mlx"
|
||||
@@ -1,61 +0,0 @@
|
||||
include(FetchContent)
|
||||
|
||||
# Read MLX version from top-level file (shared with Dockerfile)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
set(MLX_BUILD_SAFETENSORS ON)
|
||||
|
||||
function(set_target_output_directory _target)
|
||||
if(TARGET ${_target})
|
||||
set_target_properties(${_target} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
|
||||
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
|
||||
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
|
||||
)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Check for Metal support (macOS only)
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
execute_process(
|
||||
COMMAND
|
||||
zsh "-c"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(NOT MLX_METAL_VERSION)
|
||||
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
endif()
|
||||
else()
|
||||
# On Linux, disable Metal backend
|
||||
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
endif()
|
||||
|
||||
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
|
||||
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
|
||||
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
|
||||
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
|
||||
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
|
||||
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
|
||||
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
elseif(MLX_CUDA_ARCHITECTURES)
|
||||
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG})
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
set_target_output_directory(mlx)
|
||||
set_target_output_directory(mlxc)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,92 +0,0 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.dylib",
|
||||
"@loader_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"@executable_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"build/lib/ollama/libmlxc.dylib",
|
||||
"../build/lib/ollama/libmlxc.dylib",
|
||||
NULL
|
||||
};
|
||||
#else
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.so",
|
||||
"$ORIGIN/../build/lib/ollama/libmlxc.so",
|
||||
"build/lib/ollama/libmlxc.so",
|
||||
"../build/lib/ollama/libmlxc.so",
|
||||
NULL
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
// Try each possible library path
|
||||
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
|
||||
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
|
||||
if (mlx_handle != NULL) {
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", LIB_NAMES[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to load library
|
||||
const char* err = LIB_ERROR();
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. %s",
|
||||
err ? err : "Unknown error");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
@@ -1,314 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
_ "github.com/ollama/ollama/x/model/models/gemma3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
|
||||
func TestLoadModel(t *testing.T) {
|
||||
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||
b := &Backend{}
|
||||
err := b.LoadSafeTensors(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("load failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromInts(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
data := []int32{1, 2, 3, 4, 5, 6}
|
||||
a := c.FromInts(data, 2, 3)
|
||||
slog.Info("", "array", a)
|
||||
t.Log(a.ToString())
|
||||
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromFloats(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
data := []float32{1, 2, 3, 4, 5, 6}
|
||||
a := c.FromFloats(data, 2, 3)
|
||||
slog.Info("", "array", a)
|
||||
t.Log(a.ToString())
|
||||
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||
}
|
||||
res := a.Floats()
|
||||
if !reflect.DeepEqual(res, data) {
|
||||
t.Fatalf("incorrect results: %v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
|
||||
t3 := t1.Add(c, t2)
|
||||
c.Compute(t3, exp)
|
||||
t3f := t3.Floats()
|
||||
if !reflect.DeepEqual(t3f, exp.Floats()) {
|
||||
t.Fatalf("incorrect result: %v", t3f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReshapeTranspose(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
|
||||
c.Compute(t1)
|
||||
t1f := t1.Floats()
|
||||
exp := []float32{
|
||||
0, 4, 8,
|
||||
1, 5, 9,
|
||||
2, 6, 10,
|
||||
3, 7, 11,
|
||||
12, 16, 20,
|
||||
13, 17, 21,
|
||||
14, 18, 22,
|
||||
15, 19, 23,
|
||||
}
|
||||
if !reflect.DeepEqual(t1f, exp) {
|
||||
t.Fatalf("incorrect results: %v", t1f)
|
||||
}
|
||||
}
|
||||
|
||||
func prod(vals ...int) int {
|
||||
r := 1
|
||||
for _, v := range vals {
|
||||
r *= v
|
||||
}
|
||||
return r
|
||||
}
|
||||
func TestMatmul(t *testing.T) {
|
||||
// TODO create scenarios...
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
s1 := []int{1, 3, 2, 4}
|
||||
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
|
||||
s2 := []int{4, 2}
|
||||
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
|
||||
t3 := t1.Matmul(c, t2)
|
||||
exp := []float32{
|
||||
28, 34,
|
||||
76, 98,
|
||||
|
||||
124, 162,
|
||||
172, 226,
|
||||
|
||||
220, 290,
|
||||
268, 354,
|
||||
}
|
||||
c.Compute(t3)
|
||||
t3f := t3.Floats()
|
||||
if !reflect.DeepEqual(t3f, exp) {
|
||||
t.Fatalf("incorrect result: %v", t3f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRows(t *testing.T) {
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
|
||||
outputs := c.Zeros(ml.DTypeInt32, 1)
|
||||
t2 := t1.TakeAxes(c, outputs, 1)
|
||||
c.Forward(t1, t2).Compute(t1, t2)
|
||||
t.Log(t1.ToString())
|
||||
t.Log(t2.ToString())
|
||||
f := t2.Floats()
|
||||
t.Logf("Result: %v", f)
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
// Validate the caching algorithm
|
||||
b := &Backend{}
|
||||
c := b.NewContext()
|
||||
defer c.Close()
|
||||
batchSize := 3
|
||||
headDim := 4
|
||||
numKVHeads := 2
|
||||
// Make cache twice the size of one test batch
|
||||
cells := batchSize * 2
|
||||
cellSize := numKVHeads * headDim
|
||||
shape := []int{1, numKVHeads, batchSize, headDim}
|
||||
stop := float32(1)
|
||||
for _, x := range shape {
|
||||
stop *= float32(x)
|
||||
}
|
||||
// Create the cache
|
||||
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
|
||||
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
|
||||
|
||||
// Input tensor
|
||||
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
|
||||
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
|
||||
|
||||
// Reshape to copy into the cache
|
||||
/*
|
||||
From MLX python/src/indexing.cpp mlx_scatter_args_array
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
auto up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
up = reshape(up, up_shape);
|
||||
*/
|
||||
numRows := 3
|
||||
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
|
||||
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
|
||||
|
||||
// Simulate cells 1,3,5 are available
|
||||
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
|
||||
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
|
||||
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
|
||||
cache.Scatter(c, indicies, up, axis)
|
||||
|
||||
c.Forward(cache)
|
||||
// Cache should contain the data now
|
||||
t.Log("Cache after put\n" + cache.ToString())
|
||||
|
||||
// Retrieve cache content and verify it matches
|
||||
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
|
||||
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
|
||||
|
||||
t1f := t1.Floats()
|
||||
outf := out.Floats()
|
||||
if !reflect.DeepEqual(t1f, outf) {
|
||||
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma3(t *testing.T) {
|
||||
// Why is the sky blue
|
||||
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
|
||||
limit := 50
|
||||
|
||||
// TODO generalize this
|
||||
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||
|
||||
m, err := model.New(dir, ml.BackendParams{})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to load model: %s", err)
|
||||
}
|
||||
b := m.Backend()
|
||||
ctx := b.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
|
||||
Positions: make([]int32, len(inputs)),
|
||||
Sequences: make([]int, len(inputs)),
|
||||
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
|
||||
Offset: 0,
|
||||
}
|
||||
for i := range len(inputs) {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
offset := len(inputs)
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
numSlots := 1
|
||||
batchSize := 512
|
||||
numCtx := 4096
|
||||
|
||||
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
|
||||
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
|
||||
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
|
||||
|
||||
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed cache.StartForward: %s", err)
|
||||
}
|
||||
}
|
||||
opts := api.DefaultOptions()
|
||||
var grammar *sample.GrammarSampler
|
||||
sampler := sample.NewSampler(
|
||||
opts.Temperature,
|
||||
opts.TopK,
|
||||
opts.TopP,
|
||||
opts.MinP,
|
||||
opts.Seed,
|
||||
grammar,
|
||||
)
|
||||
|
||||
t.Log("Starting Forward pass loop")
|
||||
pendingResponses := []string{}
|
||||
for {
|
||||
out, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
t.Fatalf("failed forward pass: %s", err)
|
||||
}
|
||||
ctx.Forward(out)
|
||||
outputs := out.Floats()
|
||||
t.Logf("finished forward pass! length:%d", len(outputs))
|
||||
// sample a token
|
||||
logits := outputs
|
||||
token, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to sample token: %s", err)
|
||||
}
|
||||
t.Logf("Sampled token: %v", token)
|
||||
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
t.Log("hit EOS")
|
||||
break
|
||||
}
|
||||
piece, err := m.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode token: %s", err)
|
||||
}
|
||||
|
||||
pendingResponses = append(pendingResponses, piece)
|
||||
sequence := strings.Join(pendingResponses, "")
|
||||
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
|
||||
t.Logf("hit stop token: %v", stop)
|
||||
break
|
||||
}
|
||||
t.Logf("RESULTS: %s", sequence)
|
||||
batch = input.Batch{
|
||||
Inputs: ctx.FromInts([]int32{token}, 1, 1),
|
||||
Positions: make([]int32, 1),
|
||||
Sequences: make([]int, 1),
|
||||
Outputs: ctx.FromInts([]int32{0}, 1),
|
||||
Offset: offset,
|
||||
}
|
||||
offset++
|
||||
batch.Positions[0] = 0
|
||||
err = cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed cache.StartForward: %s", err)
|
||||
}
|
||||
if offset > limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,335 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/ops.h"
|
||||
|
||||
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
|
||||
|
||||
void unpack_32_4(uint8_t* data, int8_t* dst) {
|
||||
memset(dst, 0, 16);
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
|
||||
if (j % 2 != 0) {
|
||||
x <<= 4;
|
||||
}
|
||||
dst[j / 2] += x;
|
||||
}
|
||||
// Last 16 weights are in the higher bits
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
uint8_t x = (data[j + 2] >> 4);
|
||||
if (j % 2 != 0) {
|
||||
x <<= 4;
|
||||
}
|
||||
dst[8 + j / 2] += x;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q4_0 tensors.
|
||||
// Data layout is: |16 bit scale|32 x 4bit weights|.
|
||||
void extract_q4_0_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
scales[i] = *((float16_t*)data);
|
||||
biases[i] = -8 * scales[i];
|
||||
unpack_32_4(data, weights);
|
||||
weights += 16;
|
||||
data += bytes_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q4_1 tensors.
|
||||
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
|
||||
void extract_q4_1_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
scales[i] = *((float16_t*)data);
|
||||
biases[i] = *((float16_t*)(data) + 1);
|
||||
unpack_32_4(data, weights);
|
||||
weights += 16;
|
||||
data += bytes_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
// Extracts (weight, scales, biases) from Q8_0 tensors.
|
||||
// Data layout is: |16 bit scale|32 x 8bit weights|.
|
||||
void extract_q8_0_data(
|
||||
uint8_t* data,
|
||||
mlx_array* weights_arr,
|
||||
mlx_array* scales_arr,
|
||||
mlx_array* biases_arr) {
|
||||
const uint64_t weights_per_block = 32;
|
||||
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
|
||||
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||
uint8_t* block_data = data + i * bytes_per_block;
|
||||
scales[i] = *((float16_t*)block_data);
|
||||
biases[i] = -128 * scales[i];
|
||||
for (int64_t j = 0; j < weights_per_block; ++j) {
|
||||
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
|
||||
// Original data is in int8_t, so we add a bias of -128 and invert the
|
||||
// first bit.
|
||||
x ^= 1 << 7;
|
||||
weights[i * weights_per_block + j] = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Drived from ggml-quants.c
|
||||
|
||||
#define QK_K 256
|
||||
|
||||
// 6-bit quantization
|
||||
// weight is represented as x = a * q
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 6.5625 bits per weight
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||
uint16_t d; // super-block scale
|
||||
} block_q6_K;
|
||||
|
||||
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
|
||||
const int64_t nb = k / QK_K;
|
||||
block_q6_K *x = (block_q6_K *)vx;
|
||||
float16_t* y = (float16_t *)vy;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float16_t d = 0.0;
|
||||
memcpy(&d, &x[i].d, sizeof(d));
|
||||
|
||||
const uint8_t * restrict ql = x[i].ql;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const int8_t * restrict sc = x[i].scales;
|
||||
|
||||
for (int n = 0; n < QK_K; n += 128) {
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
int is = l/16;
|
||||
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
||||
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
||||
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
||||
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
||||
y[l + 0] = d * sc[is + 0] * q1;
|
||||
y[l + 32] = d * sc[is + 2] * q2;
|
||||
y[l + 64] = d * sc[is + 4] * q3;
|
||||
y[l + 96] = d * sc[is + 6] * q4;
|
||||
}
|
||||
y += 128;
|
||||
ql += 64;
|
||||
qh += 32;
|
||||
sc += 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define K_SCALE_SIZE 12
|
||||
#define GGML_COMMON_AGGR_U
|
||||
#define GGML_COMMON_AGGR_S
|
||||
|
||||
// 4-bit quantization
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 4.5 bits per weight
|
||||
typedef struct {
|
||||
union {
|
||||
struct {
|
||||
uint16_t d; // super-block scale for quantized scales
|
||||
uint16_t dmin; // super-block scale for quantized mins
|
||||
} GGML_COMMON_AGGR_S;
|
||||
uint16_t dm;
|
||||
} GGML_COMMON_AGGR_U;
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
|
||||
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63; *m = q[j + 4] & 63;
|
||||
} else {
|
||||
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||
}
|
||||
}
|
||||
|
||||
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
|
||||
block_q4_K *x = (block_q4_K *)vx;
|
||||
float16_t* y = (float16_t *)vy;
|
||||
const int nb = k / QK_K;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * q = x[i].qs;
|
||||
float16_t d = 0.0;
|
||||
memcpy(&d, &x[i].d, sizeof(d));
|
||||
float16_t min = 0.0;
|
||||
memcpy(&min, &x[i].dmin, sizeof(d));
|
||||
|
||||
int is = 0;
|
||||
uint8_t sc, m;
|
||||
for (int j = 0; j < QK_K; j += 64) {
|
||||
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
|
||||
const float16_t d1 = d * sc; const float16_t m1 = min * m;
|
||||
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
|
||||
const float16_t d2 = d * sc; const float16_t m2 = min * m;
|
||||
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
||||
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
||||
q += 32; is += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||
shape := append([]C.int{}, final_shape...)
|
||||
var weights_per_byte C.int
|
||||
if dtype == 2 || dtype == 3 {
|
||||
weights_per_byte = 2
|
||||
} else if dtype == 8 {
|
||||
weights_per_byte = 1
|
||||
} else {
|
||||
return r, fmt.Errorf("unsupported tensor type %d", dtype)
|
||||
}
|
||||
|
||||
weights_per_block := C.int(32)
|
||||
if shape[len(shape)-1]%weights_per_block != 0 {
|
||||
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
|
||||
}
|
||||
|
||||
weights_shape := append([]C.int{}, shape...)
|
||||
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
|
||||
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
|
||||
for i := range weights_shape {
|
||||
w_nbytes *= weights_shape[i]
|
||||
}
|
||||
w_data := make([]byte, w_nbytes)
|
||||
cbytes := C.CBytes(w_data)
|
||||
defer C.free(cbytes)
|
||||
weights := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&weights_shape[0],
|
||||
C.int(len(weights_shape)),
|
||||
C.MLX_UINT32,
|
||||
)
|
||||
|
||||
// For scales and bias
|
||||
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
|
||||
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
|
||||
for i := range shape {
|
||||
sb_nbytes *= shape[i]
|
||||
}
|
||||
|
||||
s_data := make([]byte, sb_nbytes)
|
||||
cbytes = C.CBytes(s_data)
|
||||
defer C.free(cbytes)
|
||||
scales := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
b_data := make([]byte, sb_nbytes)
|
||||
cbytes = C.CBytes(b_data)
|
||||
defer C.free(cbytes)
|
||||
biases := C.mlx_array_new_data(
|
||||
cbytes,
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
var bits C.int
|
||||
switch dtype {
|
||||
case 2:
|
||||
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 4
|
||||
case 3:
|
||||
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 4
|
||||
case 8:
|
||||
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||
bits = 8
|
||||
}
|
||||
groupSize := C.mlx_optional_int{value: 32, has_value: true}
|
||||
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
|
||||
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
|
||||
C.mlx_dequantize(
|
||||
&r,
|
||||
weights,
|
||||
scales,
|
||||
biases,
|
||||
groupSize,
|
||||
bitsOpt,
|
||||
nil, // TODO mode
|
||||
dtypeOpt,
|
||||
stream,
|
||||
)
|
||||
C.mlx_array_free(weights)
|
||||
C.mlx_array_free(scales)
|
||||
C.mlx_array_free(biases)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||
size := 1
|
||||
for _, d := range shape {
|
||||
size *= int(d)
|
||||
}
|
||||
fdata := make([]float16.Float16, size)
|
||||
switch dtype {
|
||||
case 14:
|
||||
C.dequant_row_q6_K(
|
||||
data,
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
C.int(size),
|
||||
)
|
||||
|
||||
case 12:
|
||||
C.dequant_row_q4_K(
|
||||
data,
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
C.int(size),
|
||||
)
|
||||
default:
|
||||
return r, fmt.Errorf("unsupported K quant")
|
||||
}
|
||||
|
||||
r = C.mlx_array_new_data(
|
||||
unsafe.Pointer(&fdata[0]),
|
||||
&shape[0],
|
||||
C.int(len(shape)),
|
||||
C.MLX_FLOAT16,
|
||||
)
|
||||
return r, nil
|
||||
}
|
||||
643
x/ml/device.go
643
x/ml/device.go
@@ -1,643 +0,0 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// GPULayers is a set of layers to be allocated on a single GPU
|
||||
type GPULayers struct {
|
||||
DeviceID
|
||||
|
||||
// Layers is a set of layer indicies to load
|
||||
Layers []int
|
||||
}
|
||||
|
||||
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
|
||||
func (g GPULayers) FirstLayer() int {
|
||||
if len(g.Layers) == 0 {
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
first := g.Layers[0]
|
||||
for i := 1; i < len(g.Layers); i++ {
|
||||
if g.Layers[i] < first {
|
||||
first = g.Layers[i]
|
||||
}
|
||||
}
|
||||
|
||||
return first
|
||||
}
|
||||
|
||||
func (g GPULayers) String() string {
|
||||
if len(g.Layers) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
slices.Sort(g.Layers)
|
||||
|
||||
contiguous := true
|
||||
base := g.Layers[0]
|
||||
for i := range g.Layers {
|
||||
if g.Layers[i] != base+i {
|
||||
contiguous = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if contiguous {
|
||||
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
|
||||
} else {
|
||||
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
|
||||
}
|
||||
}
|
||||
|
||||
// GPULayersList is a set of layer allocations across multiple GPUs
|
||||
type GPULayersList []GPULayers
|
||||
|
||||
func (l GPULayersList) Len() int { return len(l) }
|
||||
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
|
||||
|
||||
// Sort by the ordering of the layers offloaded
|
||||
func (l GPULayersList) Less(i, j int) bool {
|
||||
li := l[i].FirstLayer()
|
||||
lj := l[j].FirstLayer()
|
||||
|
||||
return li < lj
|
||||
}
|
||||
|
||||
func (l GPULayersList) String() string {
|
||||
if l.Sum() > 0 {
|
||||
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
|
||||
} else {
|
||||
return fmt.Sprintf("%v", []GPULayers(l))
|
||||
}
|
||||
}
|
||||
|
||||
// Sum is the total number of layers assigned across all GPUs
|
||||
func (l GPULayersList) Sum() int {
|
||||
var sum int
|
||||
|
||||
for _, g := range l {
|
||||
sum += len(g.Layers)
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
var h maphash.Hash
|
||||
|
||||
// Hash is an identifier of this layer assignment
|
||||
func (l GPULayersList) Hash() uint64 {
|
||||
h.Reset()
|
||||
for _, g := range l {
|
||||
if len(g.Layers) > 0 {
|
||||
h.WriteString(g.ID + g.Library)
|
||||
for _, l := range g.Layers {
|
||||
binary.Write(&h, binary.NativeEndian, int64(l))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// ErrNoMem is returned when panicing due to insufficient memory. It includes
|
||||
// the attempted memory allocation.
|
||||
type ErrNoMem struct {
|
||||
BackendMemory
|
||||
}
|
||||
|
||||
func (e ErrNoMem) Error() string {
|
||||
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
|
||||
}
|
||||
|
||||
// Minimal unique device identification
|
||||
type DeviceID struct {
|
||||
// ID is an identifier for the device for matching with system
|
||||
// management libraries. The ID is only unique for other devices
|
||||
// using the same Library.
|
||||
// This ID represents a "post filtered" view of the enumerated devices
|
||||
// if the ID is numeric
|
||||
ID string `json:"id"`
|
||||
|
||||
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
|
||||
Library string `json:"backend,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceMemory provides a breakdown of the memory needed
|
||||
// per device, such as a CPU or GPU.
|
||||
type DeviceMemory struct {
|
||||
DeviceID
|
||||
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string
|
||||
|
||||
// Weights is the per-layer memory needed for the model weights.
|
||||
Weights []uint64
|
||||
|
||||
// Cache is the per-layer memory needed for the KV cache.
|
||||
Cache []uint64
|
||||
|
||||
// Graph is the size of the compute graph. It is not per-layer.
|
||||
Graph uint64
|
||||
}
|
||||
|
||||
func sumMemory(mem []uint64) uint64 {
|
||||
var sum uint64
|
||||
|
||||
for _, m := range mem {
|
||||
sum += m
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
// Size returns the total size of the memory required by this device
|
||||
func (m DeviceMemory) Size() uint64 {
|
||||
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
|
||||
}
|
||||
|
||||
func memoryPresent(mem []uint64) bool {
|
||||
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
|
||||
}
|
||||
|
||||
func (m DeviceMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if memoryPresent(m.Weights) {
|
||||
attrs = append(attrs, slog.Any("Weights", m.Weights))
|
||||
}
|
||||
|
||||
if memoryPresent(m.Cache) {
|
||||
attrs = append(attrs, slog.Any("Cache", m.Cache))
|
||||
}
|
||||
|
||||
if m.Graph != 0 {
|
||||
attrs = append(attrs, slog.Any("Graph", m.Graph))
|
||||
}
|
||||
|
||||
if len(attrs) > 0 && m.ID != "" {
|
||||
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// BackendMemory provides the amount of memory required to load the model
|
||||
// per device based on the BackendParams. In some cases, not all required
|
||||
// allocations will be known at this point. However, the size of the most recent
|
||||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||
// accommodate that to make forward progress.
|
||||
type BackendMemory struct {
|
||||
// InputWeights are always located on the CPU and cannot be moved
|
||||
InputWeights uint64
|
||||
|
||||
// CPU model components are located in system memory. This does not
|
||||
// include unified memory allocated through the GPU.
|
||||
CPU DeviceMemory
|
||||
|
||||
// GPU model components are located on one or more GPUs.
|
||||
GPUs []DeviceMemory
|
||||
}
|
||||
|
||||
func (m BackendMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if m.InputWeights != 0 {
|
||||
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||
}
|
||||
|
||||
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
|
||||
for _, g := range m.GPUs {
|
||||
attrs = append(attrs, slog.Any(g.Name, g))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// Log prints a high level summary of the memory
|
||||
func (m BackendMemory) Log(level slog.Level) {
|
||||
var total uint64
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := sumMemory(m.CPU.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := gpu.Graph; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.CPU.Graph; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
|
||||
}
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
DeviceID
|
||||
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Description is the longer user-friendly identification of the device
|
||||
Description string `json:"description"`
|
||||
|
||||
// FilterID is populated with the unfiltered device ID if a numeric ID is used
|
||||
// so the device can be included.
|
||||
FilterID string `json:"filter_id,omitempty"`
|
||||
|
||||
// Integrated is set true for integrated GPUs, false for Discrete GPUs
|
||||
Integrated bool `json:"integration,omitempty"`
|
||||
|
||||
// PCIID is the bus, device and domain ID of the device for deduplication
|
||||
// when discovered by multiple backends
|
||||
PCIID string `json:"pci_id,omitempty"`
|
||||
|
||||
// TotalMemory is the total amount of memory the device can use for loading models
|
||||
TotalMemory uint64 `json:"total_memory"`
|
||||
|
||||
// FreeMemory is the amount of memory currently available on the device for loading models
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
|
||||
// ComputeMajor is the major version of capabilities of the device
|
||||
// if unsupported by the backend, -1 will be returned
|
||||
ComputeMajor int
|
||||
|
||||
// ComputeMinor is the minor version of capabilities of the device
|
||||
// if unsupported by the backend, -1 will be returned
|
||||
ComputeMinor int
|
||||
|
||||
// Driver Information
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
DriverMinor int `json:"driver_minor,omitempty"`
|
||||
|
||||
// Where backends were loaded from
|
||||
LibraryPath []string
|
||||
}
|
||||
|
||||
type SystemInfo struct {
|
||||
// ThreadCount is the optimal number of threads to use for inference
|
||||
ThreadCount int `json:"threads,omitempty"`
|
||||
|
||||
// TotalMemory is the total amount of system memory
|
||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||
|
||||
// FreeMemory is the amount of memory currently available on the system for loading models
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
|
||||
// FreeSwap is the amount of system swap space reported as available
|
||||
FreeSwap uint64 `json:"free_swap,omitempty"`
|
||||
}
|
||||
|
||||
func (d DeviceInfo) Compute() string {
|
||||
// AMD gfx is encoded into the major minor in hex form
|
||||
if strings.EqualFold(d.Library, "ROCm") {
|
||||
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
|
||||
}
|
||||
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
|
||||
}
|
||||
|
||||
func (d DeviceInfo) Driver() string {
|
||||
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
|
||||
}
|
||||
|
||||
// MinimumMemory reports the amount of memory that should be set aside
|
||||
// on the device for overhead (e.g. VRAM consumed by context structures independent
|
||||
// of model allocations)
|
||||
func (d DeviceInfo) MinimumMemory() uint64 {
|
||||
if d.Library == "Metal" {
|
||||
return 512 * format.MebiByte
|
||||
}
|
||||
return 457 * format.MebiByte
|
||||
}
|
||||
|
||||
// Sort by Free Space.
|
||||
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
|
||||
type ByFreeMemory []DeviceInfo
|
||||
|
||||
func (a ByFreeMemory) Len() int { return len(a) }
|
||||
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByFreeMemory) Less(i, j int) bool {
|
||||
if a[i].Integrated && !a[j].Integrated {
|
||||
return true
|
||||
} else if !a[i].Integrated && a[j].Integrated {
|
||||
return false
|
||||
}
|
||||
return a[i].FreeMemory < a[j].FreeMemory
|
||||
}
|
||||
|
||||
// ByPerformance groups devices by similar speed
|
||||
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
|
||||
resp := [][]DeviceInfo{}
|
||||
scores := []bool{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Integrated
|
||||
for i, score := range scores {
|
||||
if score == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
scores = append(scores, requested)
|
||||
resp = append(resp, []DeviceInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
|
||||
resp := [][]DeviceInfo{}
|
||||
libs := []string{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Library
|
||||
for i, lib := range libs {
|
||||
if lib == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
libs = append(libs, requested)
|
||||
resp = append(resp, []DeviceInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func LibraryPaths(l []DeviceInfo) []string {
|
||||
gpuLibs := []string{LibOllamaPath}
|
||||
for _, gpu := range l {
|
||||
for _, dir := range gpu.LibraryPath {
|
||||
needed := true
|
||||
for _, existing := range gpuLibs {
|
||||
if dir == existing {
|
||||
needed = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needed {
|
||||
gpuLibs = append(gpuLibs, dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
return gpuLibs
|
||||
}
|
||||
|
||||
type DeviceComparison int
|
||||
|
||||
const (
|
||||
UniqueDevice DeviceComparison = iota
|
||||
SameBackendDevice // The device is the same, and the library/backend is the same
|
||||
DuplicateDevice // The same physical device but different library/backend (overlapping device)
|
||||
)
|
||||
|
||||
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
|
||||
if a.PCIID != b.PCIID {
|
||||
return UniqueDevice
|
||||
}
|
||||
// If PCIID is empty, we have to use ID + library for uniqueness
|
||||
if a.PCIID == "" && a.DeviceID != b.DeviceID {
|
||||
return UniqueDevice
|
||||
}
|
||||
if a.Library == b.Library {
|
||||
return SameBackendDevice
|
||||
}
|
||||
return DuplicateDevice
|
||||
}
|
||||
|
||||
// For a SameBackendDevice, return true if b is better than a
|
||||
// e.g. newer GPU library version
|
||||
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
|
||||
aLib := a.LibraryPath[len(a.LibraryPath)-1]
|
||||
bLib := b.LibraryPath[len(b.LibraryPath)-1]
|
||||
if aLib == bLib {
|
||||
return false
|
||||
}
|
||||
aLibSplit := strings.SplitN(aLib, "_", 2)
|
||||
bLibSplit := strings.SplitN(bLib, "_", 2)
|
||||
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
|
||||
return false
|
||||
}
|
||||
if aLibSplit[0] != bLibSplit[0] {
|
||||
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
|
||||
return false
|
||||
}
|
||||
if aLibSplit[1] == bLibSplit[1] {
|
||||
return false
|
||||
}
|
||||
cmp := []string{aLibSplit[1], bLibSplit[1]}
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
|
||||
return cmp[0] == bLibSplit[1]
|
||||
}
|
||||
|
||||
// For each GPU, check if it does NOT support flash attention
|
||||
func FlashAttentionSupported(l []DeviceInfo) bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
|
||||
gpu.Library == "ROCm" ||
|
||||
gpu.Library == "Vulkan"
|
||||
|
||||
if !supportsFA {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variables
|
||||
// Set mustFilter true to enable filtering of CUDA devices
|
||||
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
env := map[string]string{}
|
||||
for _, d := range l {
|
||||
d.updateVisibleDevicesEnv(env, mustFilter)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// NeedsInitValidation returns true if the device in question has the potential
|
||||
// to crash at inference time and requires deeper validation before we include
|
||||
// it in the supported devices list.
|
||||
func (d DeviceInfo) NeedsInitValidation() bool {
|
||||
// ROCm: rocblas will crash on unsupported devices.
|
||||
// CUDA: verify CC is supported by the version of the library
|
||||
return d.Library == "ROCm" || d.Library == "CUDA"
|
||||
}
|
||||
|
||||
// Set the init validation environment variable
|
||||
func (d DeviceInfo) AddInitValidation(env map[string]string) {
|
||||
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
|
||||
}
|
||||
|
||||
// PreferredLibrary returns true if this library is preferred over the other input
|
||||
// library
|
||||
// Used to filter out Vulkan in favor of CUDA or ROCm
|
||||
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
|
||||
// TODO in the future if we find Vulkan is better than ROCm on some devices
|
||||
// that implementation can live here.
|
||||
|
||||
if d.Library == "CUDA" || d.Library == "ROCm" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
|
||||
var envVar string
|
||||
switch d.Library {
|
||||
case "ROCm":
|
||||
// ROCm must be filtered as it can crash the runner on unsupported devices
|
||||
envVar = "ROCR_VISIBLE_DEVICES"
|
||||
if runtime.GOOS != "linux" {
|
||||
envVar = "HIP_VISIBLE_DEVICES"
|
||||
}
|
||||
case "CUDA":
|
||||
if !mustFilter {
|
||||
// By default we try to avoid filtering CUDA devices because ROCm also
|
||||
// looks at the CUDA env var, and gets confused in mixed vendor environments.
|
||||
return
|
||||
}
|
||||
envVar = "CUDA_VISIBLE_DEVICES"
|
||||
default:
|
||||
// Vulkan is not filtered via env var, but via scheduling decisions
|
||||
return
|
||||
}
|
||||
v, existing := env[envVar]
|
||||
if existing {
|
||||
v = v + ","
|
||||
}
|
||||
if d.FilterID != "" {
|
||||
v = v + d.FilterID
|
||||
} else {
|
||||
v = v + d.ID
|
||||
}
|
||||
env[envVar] = v
|
||||
}
|
||||
|
||||
type BaseRunner interface {
|
||||
// GetPort returns the localhost port number the runner is running on
|
||||
GetPort() int
|
||||
|
||||
// HasExited indicates if the runner is no longer running. This can be used during
|
||||
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
|
||||
HasExited() bool
|
||||
}
|
||||
|
||||
type RunnerDiscovery interface {
|
||||
BaseRunner
|
||||
|
||||
// GetDeviceInfos will perform a query of the underlying device libraries
|
||||
// for device identification and free VRAM information
|
||||
// During bootstrap scenarios, this routine may take seconds to complete
|
||||
GetDeviceInfos(ctx context.Context) []DeviceInfo
|
||||
}
|
||||
|
||||
type FilteredRunnerDiscovery interface {
|
||||
RunnerDiscovery
|
||||
|
||||
// GetActiveDeviceIDs returns the filtered set of devices actively in
|
||||
// use by this runner for running models. If the runner is a bootstrap runner, no devices
|
||||
// will be active yet so no device IDs are returned.
|
||||
// This routine will not query the underlying device and will return immediately
|
||||
GetActiveDeviceIDs() []DeviceID
|
||||
}
|
||||
|
||||
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
|
||||
var moreDevices []DeviceInfo
|
||||
port := runner.GetPort()
|
||||
tick := time.Tick(10 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to finish discovery before timeout")
|
||||
case <-tick:
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
// slog.Warn("failed to send request", "error", err)
|
||||
if runner.HasExited() {
|
||||
return nil, fmt.Errorf("runner crashed")
|
||||
}
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// old runner, fall back to bootstrapping model
|
||||
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
slog.Warn("failed to read response", "error", err)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
|
||||
return nil, fmt.Errorf("runner error: %s", string(body))
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &moreDevices); err != nil {
|
||||
slog.Warn("unmarshal encode response", "error", err)
|
||||
continue
|
||||
}
|
||||
return moreDevices, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
// Attention implements scaled dot-product attention for transformer models:
|
||||
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
ctx.Forward(query)
|
||||
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
ctx.Forward(key, value)
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
} else if cache == nil {
|
||||
panic("key & value tensors must be provided if cache is nil")
|
||||
}
|
||||
|
||||
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
|
||||
// panic("after cache get") //
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
|
||||
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
|
||||
|
||||
// var mask ml.Tensor
|
||||
if cache != nil {
|
||||
key, value, _ = cache.Get(ctx)
|
||||
}
|
||||
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
|
||||
// panic("after cache get") //
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
|
||||
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
|
||||
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
|
||||
if cache != nil {
|
||||
// TODO what to do with vmla?
|
||||
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
|
||||
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
|
||||
|
||||
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
|
||||
} else {
|
||||
panic("else case not supported")
|
||||
// TODO transpose shapes are wrong
|
||||
// key = key.Transpose(ctx, 0, 2, 1, 3)
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
|
||||
|
||||
// kq := query.Matmul(ctx, key)
|
||||
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
// kq = kq.Softmax(ctx)
|
||||
|
||||
// kqv := kq.Matmul(ctx, value)
|
||||
|
||||
// if vmla != nil {
|
||||
// kqv = kqv.Matmul(ctx, vmla)
|
||||
// }
|
||||
|
||||
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Conv2D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
|
||||
if m.Bias != nil {
|
||||
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
|
||||
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type Conv3D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
|
||||
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Embedding struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
return m.Weight.TakeAxes(ctx, hiddenState, 0)
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
type Linear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type LinearBatch struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
|
||||
panic("not yet ported")
|
||||
// t = m.Weight.MulmatID(ctx, t, indices)
|
||||
// if m.Bias != nil {
|
||||
// t = t.AddID(ctx, m.Bias, indices)
|
||||
// }
|
||||
|
||||
// return t
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
|
||||
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
|
||||
// slog.Info("RMSNorm", "eps", eps)
|
||||
// fmt.Fprintln(os.Stderr, t.ToString())
|
||||
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
|
||||
|
||||
// TODO this is probably model specific, not generalized...
|
||||
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
|
||||
|
||||
return t.RMSNorm(ctx, w, eps)
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package pooling
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeNone Type = iota
|
||||
TypeMean
|
||||
TypeCLS
|
||||
TypeLast
|
||||
)
|
||||
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
return "Mean"
|
||||
case TypeCLS:
|
||||
return "CLS"
|
||||
case TypeLast:
|
||||
return "Last"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
switch t {
|
||||
// case TypeMean:
|
||||
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
|
||||
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
// case TypeCLS:
|
||||
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
|
||||
// case TypeLast:
|
||||
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
|
||||
default:
|
||||
panic("unknown pooling type")
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package rope
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
// Options contains optional parameters for RoPE function
|
||||
type Options struct {
|
||||
Type int
|
||||
Factors ml.Tensor
|
||||
|
||||
// YaRN options
|
||||
YaRN struct {
|
||||
OriginalContextLength int
|
||||
ExtrapolationFactor,
|
||||
AttentionFactor,
|
||||
BetaFast,
|
||||
BetaSlow float32
|
||||
}
|
||||
|
||||
// MRoPE options
|
||||
MRoPE struct {
|
||||
Sections []int
|
||||
}
|
||||
}
|
||||
|
||||
// WithTypeNeoX sets RoPE type to NeoX
|
||||
func WithTypeNeoX() func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type = 2
|
||||
}
|
||||
}
|
||||
|
||||
// WithFactors sets custom rope factors
|
||||
func WithFactors(factors ml.Tensor) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
if factors != nil {
|
||||
opts.Factors = factors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithOriginalContextLength sets a custom context length
|
||||
func WithOriginalContextLength(n int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.OriginalContextLength = n
|
||||
}
|
||||
}
|
||||
|
||||
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.ExtrapolationFactor = extrapolationFactor
|
||||
}
|
||||
}
|
||||
|
||||
func WithAttentionFactor(attentionFactor float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.AttentionFactor = attentionFactor
|
||||
}
|
||||
}
|
||||
|
||||
func WithMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1 << 3
|
||||
opts.MRoPE.Sections = sections
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1<<3 | 1<<5
|
||||
opts.MRoPE.Sections = sections
|
||||
}
|
||||
}
|
||||
56
x/ml/path.go
56
x/ml/path.go
@@ -1,56 +0,0 @@
|
||||
package ml
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// LibPath is a path to lookup dynamic libraries
|
||||
// in development it's usually 'build/lib/ollama'
|
||||
// in distribution builds it's 'lib/ollama' on Windows
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
var libPath string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
|
||||
case "linux":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
|
||||
case "darwin":
|
||||
libPath = filepath.Dir(exe)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
paths := []string{
|
||||
libPath,
|
||||
|
||||
// build paths for development
|
||||
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
|
||||
filepath.Join(cwd, "build", "lib", "ollama"),
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Dir(exe)
|
||||
}()
|
||||
@@ -1,322 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
types := make([]int32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = 1
|
||||
}
|
||||
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
tokens = append(tokens, token) //nolint:makezero
|
||||
types = append(types, 3) //nolint:makezero
|
||||
vocab[token] = int32(len(vocab))
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
merges := make([]string, 0, 50000)
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "#") {
|
||||
merges = append(merges, scanner.Text())
|
||||
}
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
}
|
||||
|
||||
func TestLlama(t *testing.T) {
|
||||
tokenizer := llama(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids, err := tokenizer.Encode("hello world", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
|
||||
s, err := tokenizer.Decode([]int32{15339, 1917})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s != "hello world" {
|
||||
t.Errorf("got %q, want hello world", s)
|
||||
}
|
||||
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple repeated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
strings.Repeat("0", 1): {15},
|
||||
strings.Repeat("0", 2): {410},
|
||||
strings.Repeat("0", 3): {931},
|
||||
strings.Repeat("0", 4): {931, 15},
|
||||
strings.Repeat("0", 5): {931, 410},
|
||||
strings.Repeat("0", 6): {931, 931},
|
||||
strings.Repeat("0", 7): {931, 931, 15},
|
||||
strings.Repeat("0", 8): {931, 931, 410},
|
||||
strings.Repeat("0", 9): {931, 931, 931},
|
||||
strings.Repeat("0", 10): {931, 931, 931, 15},
|
||||
strings.Repeat("0", 11): {931, 931, 931, 410},
|
||||
strings.Repeat("0", 12): {931, 931, 931, 931},
|
||||
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
|
||||
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
|
||||
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
|
||||
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
|
||||
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
|
||||
"Hello World": {"Hello", " ", " World"},
|
||||
"Hello\nWorld": {"Hello", "\n", "World"},
|
||||
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for b := 0x00; b <= 0xFF; b++ {
|
||||
input := string(rune(b))
|
||||
ids, err := tokenizer.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if b == 0x00 {
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != input {
|
||||
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
n := min(int(math.Pow10(i)), len(bts))
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
ids, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
patterns,
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
patterns: []string{
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||
},
|
||||
{
|
||||
name: "individual digits",
|
||||
patterns: []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
package input
|
||||
|
||||
import "github.com/ollama/ollama/x/ml"
|
||||
|
||||
// Multimodal is a multimodal embedding or a component of one.
|
||||
// For example, it could be a row of an image that can be processed
|
||||
// independently.
|
||||
type Multimodal struct {
|
||||
// Tensor is the embedding data. Implementations may chose what to
|
||||
// store here or it may be nil if not needed. However, any ml.Tensor
|
||||
// objects must be stored here and not in Data.
|
||||
Tensor ml.Tensor
|
||||
|
||||
// Data is implementation-specific opaque data, such as metadata on how
|
||||
// to layout Tensor. It may be nil if not needed. It may also store larger
|
||||
// objects such as complete images if they are to be processed later.
|
||||
Data any
|
||||
}
|
||||
|
||||
// Input represents one token in the input stream
|
||||
type Input struct {
|
||||
// Token is a single element of text.
|
||||
Token int32
|
||||
|
||||
// Multimodal is represents a non-text element such as an
|
||||
// image (or part of one if the image can be processed in pieces).
|
||||
// It may be used either together with Token or on its own.
|
||||
Multimodal []Multimodal
|
||||
|
||||
// MultimodalHash is a unique representation of the data
|
||||
// stored in Multimodal, used for caching and comparing
|
||||
// equality.
|
||||
MultimodalHash uint64
|
||||
|
||||
// SameBatch forces the following number of tokens to be processed
|
||||
// in a single batch, breaking and extending batches as needed.
|
||||
// Useful for things like images that must be processed in one
|
||||
// shot.
|
||||
SameBatch int
|
||||
}
|
||||
|
||||
// MultimodalIndex is a multimodal element (such as an image)
|
||||
// together with an index into the slice of Inputs with the
|
||||
// corresponding token. Note that the index is not the same
|
||||
// as the position - to find that use the index with the
|
||||
// Positions slice.
|
||||
type MultimodalIndex struct {
|
||||
Index int
|
||||
Multimodal []Multimodal
|
||||
}
|
||||
|
||||
// Batch contains the inputs for a model forward pass
|
||||
type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs ml.Tensor
|
||||
|
||||
// TODO maybe not the optimal way to handle this
|
||||
// Offset of final tensor in the final batch
|
||||
Offset int
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
Positions []int32
|
||||
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
}
|
||||
333
x/model/model.go
333
x/model/model.go
@@ -1,333 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
_ "golang.org/x/image/bmp"
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
_ "github.com/ollama/ollama/x/ml/backend"
|
||||
"github.com/ollama/ollama/x/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
ErrUnsupportedModel = errors.New("model not supported")
|
||||
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||
)
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
// generates an output (typically an embedding) that can be used by the model.
|
||||
//
|
||||
// The return value is one or more tensors, each with optional model-specific
|
||||
// opaque metadata. Typically, the tensors might be views into an embedding
|
||||
// with each view representing a chunk of data that can be processed independently
|
||||
// in different batches.
|
||||
//
|
||||
// The result may be cached by the runner.
|
||||
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
|
||||
|
||||
// PostTokenize is called after tokenization to allow the model to edit the
|
||||
// input stream to correctly arrange multimodal elements.
|
||||
//
|
||||
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
|
||||
// in the order that the user provided them. Each element of the slice will be
|
||||
// either a single token or single multimodal object.
|
||||
//
|
||||
// The model must ensure that inputs are stored according to how they will be
|
||||
// processed and stored in the cache. For example, Llava-style models should insert
|
||||
// placeholder tokens equal to the feature size of the corresponding image with
|
||||
// the image itself attached to and split across these tokens. When Forward is called
|
||||
// a partial subset of these tokens may be submitted according to the batch size.
|
||||
//
|
||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
type Base struct {
|
||||
b ml.Backend
|
||||
config
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Cache kvcache.Cache
|
||||
}
|
||||
|
||||
// Backend returns the underlying backend that will run the model
|
||||
func (m *Base) Backend() ml.Backend {
|
||||
return m.b
|
||||
}
|
||||
|
||||
func (m *Base) Config() config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
var models = make(map[string]func(fs.Config) (Model, error))
|
||||
|
||||
// Register registers a model constructor for the given architecture
|
||||
func Register(name string, f func(fs.Config) (Model, error)) {
|
||||
if _, ok := models[name]; ok {
|
||||
panic("model: model already registered")
|
||||
}
|
||||
|
||||
models[name] = f
|
||||
}
|
||||
|
||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
b, err := ml.NewBackend(modelPath, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := modelForArch(b.Config())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
meta, err := fsggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := modelForArch(meta.KV())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
func modelForArch(c fs.Config) (Model, error) {
|
||||
arch := c.Architecture()
|
||||
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedModel
|
||||
}
|
||||
|
||||
return f(c)
|
||||
}
|
||||
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
|
||||
if t.Kind() == reflect.Struct {
|
||||
allNil := true
|
||||
for i := range t.NumField() {
|
||||
tt := t.Field(i).Type
|
||||
vv := v.Field(i)
|
||||
if !vv.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// make a copy
|
||||
tagsCopy := tags
|
||||
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
||||
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||
}
|
||||
|
||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||
vv.Set(reflect.ValueOf(base))
|
||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||
var fn func([]Tag, string, string) [][]string
|
||||
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||
if len(tags) > 0 {
|
||||
var names []string
|
||||
if tags[0].name != "" {
|
||||
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
|
||||
names = append(names, prefix+n+suffix)
|
||||
}
|
||||
}
|
||||
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
|
||||
if len(names) == 0 {
|
||||
// current tag has no name, use child names only
|
||||
fullNames = append(fullNames, childNames...)
|
||||
} else if len(childNames) == 0 {
|
||||
// current tag has names but no children, create branches for each name
|
||||
for _, name := range names {
|
||||
fullNames = append(fullNames, []string{name})
|
||||
}
|
||||
} else {
|
||||
// merge each name with each child
|
||||
for _, name := range names {
|
||||
for _, childName := range childNames {
|
||||
fullNames = append(fullNames, append([]string{name}, childName...))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fullNames
|
||||
}
|
||||
|
||||
names := fn(tagsCopy, "", "")
|
||||
for _, name := range names {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
logutil.Trace("found tensor", "", tensor)
|
||||
vv.Set(reflect.ValueOf(tensor))
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||
setPointer(base, vv, tagsCopy)
|
||||
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||
for i := range vv.Len() {
|
||||
vvv := vv.Index(i)
|
||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
|
||||
} else {
|
||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !canNil(tt) || !vv.IsNil() {
|
||||
allNil = false
|
||||
}
|
||||
}
|
||||
|
||||
if allNil {
|
||||
return reflect.Zero(t)
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||
vv := v
|
||||
if v.Kind() == reflect.Interface {
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
}
|
||||
|
||||
vv = reflect.Indirect(vv)
|
||||
if v.IsNil() {
|
||||
vv = reflect.New(v.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if f := populateFields(base, vv, tags...); f.CanAddr() {
|
||||
v.Set(f.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
name,
|
||||
// prefix and suffix are applied to child tags
|
||||
prefix,
|
||||
suffix string
|
||||
alternatives []string
|
||||
}
|
||||
|
||||
func parseTag(s string) (tag Tag) {
|
||||
parts := strings.Split(s, ",")
|
||||
if len(parts) > 0 {
|
||||
tag.name = parts[0]
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
|
||||
// elevate alternative to primary if no primary given
|
||||
tag.name = value
|
||||
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
|
||||
} else if ok {
|
||||
tag.alternatives = append(tag.alternatives, value)
|
||||
}
|
||||
if value, ok := strings.CutPrefix(part, "pre:"); ok {
|
||||
tag.prefix = value
|
||||
}
|
||||
if value, ok := strings.CutPrefix(part, "suf:"); ok {
|
||||
tag.suffix = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func canNil(t reflect.Type) bool {
|
||||
return t.Kind() == reflect.Chan ||
|
||||
t.Kind() == reflect.Func ||
|
||||
t.Kind() == reflect.Interface ||
|
||||
t.Kind() == reflect.Map ||
|
||||
t.Kind() == reflect.Pointer ||
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
||||
if len(batch.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
"github.com/ollama/ollama/x/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
|
||||
Dense [2]*nn.Linear `gguf:"dense"`
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,157 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
|
||||
*VisionModel `gguf:"vision_tower.vision_model"`
|
||||
*TextModel `gguf:"language_model.model"`
|
||||
|
||||
*MultiModalProjector `gguf:"multi_modal_projector"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type MultiModalProjector struct {
|
||||
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
|
||||
InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
|
||||
|
||||
tokensPerImage int
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
|
||||
l := visionOutputs.Dim(0)
|
||||
|
||||
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
patchesPerImage := imageSize / patchSize
|
||||
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
|
||||
|
||||
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
|
||||
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
|
||||
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
|
||||
|
||||
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
|
||||
visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// slog.Info("XXX Config", "c", c)
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{
|
||||
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
|
||||
},
|
||||
}
|
||||
|
||||
// slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
||||
|
||||
// TODO need to implement sliding window...
|
||||
m.Cache = kvcache.NewCausalCache()
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Token: 256000}, // <end_of_image>
|
||||
&input.Input{Token: 108}, // "\n\n"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3", New)
|
||||
model.Register("gemma3_embed", newEmbedModel)
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/kvcache"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
type TextConfig struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase, ropeGlobalBase float32
|
||||
largeModelScaling bool
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
|
||||
Layers []TextLayer `gguf:"layers"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"norm"`
|
||||
Output *nn.Linear `gguf:"embed_tokens"`
|
||||
|
||||
*TextConfig
|
||||
}
|
||||
|
||||
const (
|
||||
gemmaGlobalCacheCount = 6
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
// const (
|
||||
// cacheTypeSWA = iota
|
||||
// cacheTypeCausal
|
||||
// )
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
numBlocks := int(c.Uint("block_count"))
|
||||
|
||||
m := TextModel{
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextConfig: &TextConfig{
|
||||
hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
|
||||
numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
|
||||
ropeScale: 1, // 1 - default is 1, implied in python code
|
||||
// vocabSize: vocabSize, // 262144
|
||||
// attnValLen: int(c.Uint("attention.value_length", 256)), //256
|
||||
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||
// (8 instead of 1)
|
||||
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
},
|
||||
}
|
||||
if numBlocks == gemma27BLayerCount {
|
||||
m.largeModelScaling = true
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"q_proj"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
|
||||
Key *nn.Linear `gguf:"k_proj"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
|
||||
Value *nn.Linear `gguf:"v_proj"`
|
||||
Output *nn.Linear `gguf:"o_proj"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
B := hiddenState.Dim(0)
|
||||
L := hiddenState.Dim(1)
|
||||
ropeBase := opts.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = opts.ropeGlobalBase
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
|
||||
k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
|
||||
v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
traditional := false
|
||||
q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
|
||||
k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
|
||||
|
||||
// TODO - this is wrong somehow so commenting out
|
||||
// if opts.largeModelScaling {
|
||||
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
// } else {
|
||||
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
|
||||
// }
|
||||
|
||||
scaleFactor := math.Pow(256, -0.5)
|
||||
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// ropeBase := m.TextConfig.ropeLocalBase
|
||||
// if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
// ropeBase = m.TextConfig.ropeGlobalBase
|
||||
// }
|
||||
// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
|
||||
panic("not yet implemented")
|
||||
// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Up *nn.Linear `gguf:"up_proj"`
|
||||
Down *nn.Linear `gguf:"down_proj"`
|
||||
Gate *nn.Linear `gguf:"gate_proj"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
|
||||
SelfAttention *TextSelfAttention `gguf:"self_attn"`
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
|
||||
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
|
||||
MLP *TextMLP `gguf:"mlp"`
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
|
||||
}
|
||||
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
residual := hiddenState
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
|
||||
residual = residual.TakeAxes(ctx, outputs, 1)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
|
||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
// var except []int
|
||||
// for _, image := range batch.Multimodal {
|
||||
// visionOutputs := image.Multimodal[0].Tensor
|
||||
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
|
||||
// []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
|
||||
// []int{image.Index * hiddenState.Stride(1)}, 0)))
|
||||
|
||||
// for i := range visionOutputs.Dim(1) {
|
||||
// except = append(except, image.Index+i)
|
||||
// }
|
||||
// }
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
// gemma alternates between the sliding window (local) and causal (global)
|
||||
// kv cache every 6 layers
|
||||
if cache != nil {
|
||||
// cacheType := cacheTypeSWA
|
||||
// if (i+1)%gemmaGlobalCacheCount == 0 {
|
||||
// cacheType = cacheTypeCausal
|
||||
// }
|
||||
cache.SetLayer(i)
|
||||
|
||||
// TODO this needs to come back
|
||||
// wc := cache.(*kvcache.WrapperCache)
|
||||
// wc.SetLayerType(cacheType)
|
||||
|
||||
// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
// causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
// }
|
||||
}
|
||||
|
||||
var offset int
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
offset = batch.Offset
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
|
||||
}
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int = 1
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"self_attn.q_proj"`
|
||||
Key *nn.Linear `gguf:"self_attn.k_proj"`
|
||||
Value *nn.Linear `gguf:"self_attn.v_proj"`
|
||||
Output *nn.Linear `gguf:"self_attn.out_proj"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `gguf:"fc1"`
|
||||
FC2 *nn.Linear `gguf:"fc2"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
|
||||
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
|
||||
MLP *VisionMLP `gguf:"mlp"`
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// self attention
|
||||
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// feed forward
|
||||
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize, numHeads int
|
||||
imageSize, patchSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
|
||||
PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
|
||||
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"encoder.layers"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
|
||||
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
|
||||
|
||||
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
|
||||
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon"),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"image"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize, patchSize, numChannels int
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
numChannels: int(c.Uint("vision.num_channels")),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
|
||||
var pixelVals, rVals, gVals, bVals []float32
|
||||
|
||||
bounds := img.Bounds()
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
c := img.At(x, y)
|
||||
r, g, b, _ := c.RGBA()
|
||||
rVal := float32(r>>8) / 255.0
|
||||
gVal := float32(g>>8) / 255.0
|
||||
bVal := float32(b>>8) / 255.0
|
||||
|
||||
rVal = (rVal - mean[0]) / std[0]
|
||||
gVal = (gVal - mean[1]) / std[1]
|
||||
bVal = (bVal - mean[2]) / std[2]
|
||||
|
||||
rVals = append(rVals, rVal)
|
||||
gVals = append(gVals, gVal)
|
||||
bVals = append(bVals, bVal)
|
||||
}
|
||||
}
|
||||
|
||||
pixelVals = append(pixelVals, rVals...)
|
||||
pixelVals = append(pixelVals, gVals...)
|
||||
pixelVals = append(pixelVals, bVals...)
|
||||
|
||||
return pixelVals
|
||||
}
|
||||
|
||||
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
|
||||
outputSize := image.Point{p.imageSize, p.imageSize}
|
||||
newImage := imageproc.Composite(img)
|
||||
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
|
||||
|
||||
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
|
||||
return data, nil
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
package models
|
||||
|
||||
// _ "github.com/ollama/ollama/x/model/models/gemma3"
|
||||
@@ -1,249 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
type SentencePiece struct {
|
||||
maxTokenLen int
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
var maxTokenLen int
|
||||
for cnt := range vocab.Types {
|
||||
switch vocab.Types[cnt] {
|
||||
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
|
||||
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
|
||||
fallthrough
|
||||
default:
|
||||
counter[int(vocab.Types[cnt])] += 1
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
return SentencePiece{
|
||||
maxTokenLen: maxTokenLen,
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
id := spm.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
|
||||
|
||||
if id := spm.vocab.Encode(text); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
q := &queue{}
|
||||
heap.Init(q)
|
||||
|
||||
runes := []rune(text)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
size: len(left) + len(right),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for q.Len() > 0 {
|
||||
pair := heap.Pop(q).(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
|
||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if token := string(merge.runes); token != "" {
|
||||
id := spm.vocab.Encode(token)
|
||||
|
||||
if id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Fallback to byte tokenization
|
||||
var result []int32
|
||||
for _, b := range []byte(token) {
|
||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||
unknownID := spm.vocab.Encode(byteToken)
|
||||
if unknownID >= 0 {
|
||||
result = append(result, unknownID)
|
||||
} else {
|
||||
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
|
||||
}
|
||||
}
|
||||
|
||||
ids = append(ids, result...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = spm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
a, b int
|
||||
score float32
|
||||
size int
|
||||
}
|
||||
|
||||
type queue []*candidate
|
||||
|
||||
func (q queue) Len() int { return len(q) }
|
||||
|
||||
func (q queue) Less(i, j int) bool {
|
||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||
}
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*q = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||
}
|
||||
|
||||
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
)
|
||||
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var spm sentencepiece.ModelProto
|
||||
if err := proto.Unmarshal(bts, &spm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var v Vocabulary
|
||||
|
||||
for _, piece := range spm.GetPieces() {
|
||||
v.Values = append(v.Values, piece.GetPiece())
|
||||
v.Scores = append(v.Scores, piece.GetScore())
|
||||
switch t := piece.GetType(); t {
|
||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
|
||||
sentencepiece.ModelProto_SentencePiece_CONTROL,
|
||||
sentencepiece.ModelProto_SentencePiece_UNUSED,
|
||||
sentencepiece.ModelProto_SentencePiece_BYTE:
|
||||
v.Types = append(v.Types, int32(t))
|
||||
default:
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
// todo parse the special tokens file
|
||||
// - this will roundtrip correctly but the <start_of_turn> and
|
||||
// <end_of_turn> tokens aren't processed
|
||||
v.Types = append(v.Types, tt)
|
||||
}
|
||||
}
|
||||
|
||||
return NewSentencePiece(&v)
|
||||
}
|
||||
|
||||
func TestSentencePieceEncode(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
tokenizer := loadSentencePieceVocab(t)
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
"你好",
|
||||
"Hello 你好 world!",
|
||||
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
|
||||
"Numbers and symbols: 123456789 +- */",
|
||||
"Special tokens: <bos> text <eos>",
|
||||
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
|
||||
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
|
||||
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
|
||||
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q [%#v]", got, want, ids)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special tokens", func(t *testing.T) {
|
||||
type candidate struct {
|
||||
token string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
cases := []candidate{
|
||||
{"<bos>", []int32{2}},
|
||||
{"<eos>", []int32{1}},
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want.token, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !slices.Equal(ids, want.ids) {
|
||||
t.Errorf("got %#v, want %#v", ids, want.ids)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"normal",
|
||||
"<0xEA>",
|
||||
"<0x41>",
|
||||
"<0xC3>",
|
||||
"<0xA3>",
|
||||
},
|
||||
Types: []int32{
|
||||
TOKEN_TYPE_NORMAL,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
},
|
||||
Scores: []float32{0, 0, 0, 0, 0},
|
||||
}
|
||||
|
||||
spm := NewSentencePiece(vocab)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single byte token",
|
||||
ids: []int32{1},
|
||||
expected: "\xea",
|
||||
},
|
||||
{
|
||||
name: "ASCII byte token",
|
||||
ids: []int32{2},
|
||||
expected: "A",
|
||||
},
|
||||
{
|
||||
name: "multiple byte tokens forming UTF-8 character",
|
||||
ids: []int32{3, 4},
|
||||
expected: "ã",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := spm.Decode(tt.ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package model
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
TOKEN_TYPE_UNKNOWN
|
||||
TOKEN_TYPE_CONTROL
|
||||
TOKEN_TYPE_USER_DEFINED
|
||||
TOKEN_TYPE_UNUSED
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type TextProcessor interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
Vocabulary() *Vocabulary
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Special int32
|
||||
|
||||
const (
|
||||
SpecialBOS Special = iota
|
||||
SpecialEOS
|
||||
)
|
||||
|
||||
type Vocabulary struct {
|
||||
Values []string
|
||||
Types []int32
|
||||
Scores []float32
|
||||
Merges []string
|
||||
|
||||
BOS, EOS []int32
|
||||
AddBOS, AddEOS bool
|
||||
|
||||
specialOnce sync.Once
|
||||
special []string
|
||||
|
||||
valuesOnce sync.Once
|
||||
values map[string]int32
|
||||
|
||||
mergeOnce sync.Once
|
||||
merge map[string]int32
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Is(id int32, special Special) bool {
|
||||
switch special {
|
||||
case SpecialBOS:
|
||||
return slices.Contains(v.BOS, id)
|
||||
case SpecialEOS:
|
||||
return slices.Contains(v.EOS, id)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
||||
if v.AddBOS && len(v.BOS) > 0 {
|
||||
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
|
||||
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
|
||||
ids = append([]int32{v.BOS[0]}, ids...)
|
||||
}
|
||||
|
||||
if v.AddEOS && len(v.EOS) > 0 {
|
||||
if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
|
||||
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
|
||||
ids = append(ids, v.EOS[0])
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Encode(s string) int32 {
|
||||
v.valuesOnce.Do(func() {
|
||||
v.values = make(map[string]int32, len(v.Values))
|
||||
for i, value := range v.Values {
|
||||
v.values[value] = int32(i)
|
||||
}
|
||||
})
|
||||
|
||||
if id, ok := v.values[s]; ok {
|
||||
return id
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Decode(id int32) string {
|
||||
return v.Values[id]
|
||||
}
|
||||
|
||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||
v.specialOnce.Do(func() {
|
||||
for i := range v.Values {
|
||||
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
|
||||
v.special = append(v.special, v.Values[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return v.special
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Merge(left, right string) int {
|
||||
v.mergeOnce.Do(func() {
|
||||
v.merge = make(map[string]int32, len(v.Merges))
|
||||
for i, merge := range v.Merges {
|
||||
v.merge[merge] = int32(i)
|
||||
}
|
||||
})
|
||||
|
||||
if id, ok := v.merge[left+" "+right]; ok {
|
||||
return int(id)
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestSpecialVocabulary(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
||||
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
||||
}
|
||||
|
||||
specialVocab := vocab.SpecialVocabulary()
|
||||
|
||||
if len(specialVocab) != 4 {
|
||||
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddSpecialVocabulary(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
vocab *Vocabulary
|
||||
input []int32
|
||||
want []int32
|
||||
}{
|
||||
{
|
||||
name: "add bos",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: false,
|
||||
},
|
||||
input: []int32{2, 3, 4},
|
||||
want: []int32{0, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
// TODO(mxyng): this is to match previous behaviour
|
||||
name: "add bos when already present",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: false,
|
||||
},
|
||||
input: []int32{0, 2, 3, 4},
|
||||
want: []int32{0, 0, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "add eos",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: false,
|
||||
AddEOS: true,
|
||||
},
|
||||
input: []int32{2, 3, 4},
|
||||
want: []int32{2, 3, 4, 1},
|
||||
},
|
||||
{
|
||||
// TODO(mxyng): this is to match previous behaviour
|
||||
name: "add eos when already present",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: false,
|
||||
AddEOS: true,
|
||||
},
|
||||
input: []int32{2, 3, 4, 1},
|
||||
want: []int32{2, 3, 4, 1, 1},
|
||||
},
|
||||
{
|
||||
name: "add both",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
},
|
||||
input: []int32{2, 3, 4},
|
||||
want: []int32{0, 2, 3, 4, 1},
|
||||
},
|
||||
{
|
||||
name: "add bos to empty inputs",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: false,
|
||||
},
|
||||
input: []int32{},
|
||||
want: []int32{0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.vocab.addSpecials(tt.input)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("no match (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type WordPiece struct {
|
||||
vocab *Vocabulary
|
||||
lowercase bool
|
||||
}
|
||||
|
||||
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
|
||||
// this differs from original word piece which uses "##" to indicate subwords.
|
||||
const ggmlPrefix = "▁"
|
||||
|
||||
var wordPieceReplacer = strings.NewReplacer(
|
||||
" .", ".",
|
||||
" ?", "?",
|
||||
" !", "!",
|
||||
" ,", ",",
|
||||
" ' ", "'",
|
||||
" n't", "n't",
|
||||
" 'm", "'m",
|
||||
" do not", " don't",
|
||||
" 's", "'s",
|
||||
" 've", "'ve",
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements TextProcessor.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
if id < 0 || int(id) >= len(wpm.vocab.Values) {
|
||||
return "", fmt.Errorf("invalid token id: %d", id)
|
||||
}
|
||||
|
||||
var separator string
|
||||
piece := wpm.vocab.Values[id]
|
||||
if i > 0 &&
|
||||
(strings.HasPrefix(piece, ggmlPrefix) ||
|
||||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
|
||||
separator = " "
|
||||
}
|
||||
|
||||
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// words splits a string into words, treating CJK characters as separate words.
|
||||
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
|
||||
func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
runes := make([]rune, 0, len(s)*3)
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 0x4E00 && r <= 0x9FFF,
|
||||
r >= 0x3400 && r <= 0x4DBF,
|
||||
r >= 0x20000 && r <= 0x2A6DF,
|
||||
r >= 0x2A700 && r <= 0x2B73F,
|
||||
r >= 0x2B740 && r <= 0x2B81F,
|
||||
r >= 0x2B820 && r <= 0x2CEAF,
|
||||
r >= 0xF900 && r <= 0xFAFF,
|
||||
r >= 0x2F800 && r <= 0x2FA1F:
|
||||
runes = append(runes, ' ', r, ' ')
|
||||
default:
|
||||
runes = append(runes, r)
|
||||
}
|
||||
}
|
||||
|
||||
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
|
||||
// split on but keep punctuation
|
||||
var start int
|
||||
for start < len(w) {
|
||||
end := strings.IndexFunc(w[start:], unicode.IsPunct)
|
||||
if end < 0 {
|
||||
end = len(w) - start
|
||||
} else if end == 0 {
|
||||
end = 1
|
||||
}
|
||||
|
||||
if !yield(w[start : start+end]) {
|
||||
return
|
||||
}
|
||||
|
||||
start += end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements TextProcessor.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
// TODO: use [UNK] from config
|
||||
unk := wpm.vocab.Encode("[UNK]")
|
||||
for word := range wpm.words(s) {
|
||||
var start int
|
||||
var pieces []int32
|
||||
for start < len(word) {
|
||||
end := len(word)
|
||||
|
||||
var piece int32
|
||||
for start < end {
|
||||
subword := word[start:end]
|
||||
if start == 0 {
|
||||
subword = ggmlPrefix + subword
|
||||
}
|
||||
|
||||
if wpm.lowercase {
|
||||
subword = strings.ToLower(subword)
|
||||
}
|
||||
piece = wpm.vocab.Encode(subword)
|
||||
if piece >= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
end--
|
||||
}
|
||||
|
||||
if piece < 0 {
|
||||
// Unknown token
|
||||
pieces = pieces[:0]
|
||||
break
|
||||
}
|
||||
|
||||
pieces = append(pieces, piece)
|
||||
start = end
|
||||
}
|
||||
|
||||
if len(pieces) > 0 {
|
||||
ids = append(ids, pieces...)
|
||||
} else {
|
||||
ids = append(ids, unk)
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = wpm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements TextProcessor.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements TextProcessor.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
vocab: vocab,
|
||||
lowercase: lowercase,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user