Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
f52671ecc6 refactor: bpe and spm tokenizer merges
- merge candidates and pairs which are essentially the same other than
  the type for rank/score
- use binaryheap in sentencepiece instead of implement custom structure
- update merging algorithm so it uses about 15% less allocations
2025-12-17 13:36:32 -08:00
Michael Yang
05711b77da re-enable new engine granite embedding 2025-12-17 13:00:51 -08:00
4 changed files with 79 additions and 136 deletions

View File

@@ -7,7 +7,7 @@ import (
"strings" "strings"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap" "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
) )
@@ -84,16 +84,15 @@ type fragment struct {
ids []int32 ids []int32
} }
// pair is a pair of runes and its rank // pair is a pair of merges and its rank
type pair struct { type pair[T int | float32] struct {
a, b int a, b *merge
rank int rank T
value string
} }
type merge struct { type merge struct {
p, n int offset, size int
runes []rune prev, next *merge
} }
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
@@ -156,84 +155,65 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
} }
runes := []rune(sb.String()) runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes { root := &merge{offset: len(runes) - 1, size: 1}
merges[r] = merge{ for i := len(runes) - 2; i >= 0; i-- {
p: r - 1, m := &merge{offset: i, size: 1, next: root}
n: r + 1, root.prev = m
runes: []rune{runes[r]}, root = m
}
pairwise := func(a, b *merge) *pair[int] {
if a != nil && b != nil {
aa := string(runes[a.offset : a.offset+a.size])
bb := string(runes[b.offset : b.offset+b.size])
if rank := bpe.vocab.Merge(aa, bb); rank >= 0 {
return &pair[int]{a: a, b: b, rank: rank}
} }
} }
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil return nil
} }
left, right := string(merges[a].runes), string(merges[b].runes) pairs := binaryheap.NewWith(func(i, j *pair[int]) int { return cmp.Compare(i.rank, j.rank) })
rank := bpe.vocab.Merge(left, right) for m := root; m != nil; m = m.next {
if rank < 0 { if pair := pairwise(m, m.next); pair != nil {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair) pairs.Push(pair)
} }
} }
for !pairs.Empty() { for !pairs.Empty() {
pair, _ := pairs.Pop() p, _ := pairs.Pop()
a := string(runes[p.a.offset : p.a.offset+p.a.size])
left, right := merges[pair.a], merges[pair.b] b := string(runes[p.b.offset : p.b.offset+p.b.size])
if len(left.runes) == 0 || len(right.runes) == 0 || if a == "" || b == "" || bpe.vocab.Merge(a, b) != p.rank {
string(left.runes)+string(right.runes) != pair.value {
continue continue
} }
if id := bpe.vocab.Encode(pair.value); id < 0 { p.a.size += p.b.size
continue p.b.size = 0
p.a.next = p.b.next
if p.b.next != nil {
p.b.next.prev = p.a
} }
merges[pair.a].runes = append(left.runes, right.runes...) if pair := pairwise(p.a.prev, p.a); pair != nil {
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair) pairs.Push(pair)
} }
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { if pair := pairwise(p.a, p.a.next); pair != nil {
pairs.Push(pair) pairs.Push(pair)
} }
} }
for _, merge := range merges { for m := root; m != nil; m = m.next {
if len(merge.runes) > 0 { if id := bpe.vocab.Encode(string(runes[m.offset : m.offset+m.size])); id >= 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
} }
} }
} }
} }
}
if addSpecial { if addSpecial {
ids = bpe.vocab.addSpecials(ids) ids = bpe.vocab.addSpecials(ids)

View File

@@ -133,6 +133,7 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{ BOS: []int32{
int32(cmp.Or( int32(cmp.Or(
@@ -157,6 +158,8 @@ func New(c fs.Config) (model.Model, error) {
switch c.String("tokenizer.ggml.model", "bert") { switch c.String("tokenizer.ggml.model", "bert") {
case "bert": case "bert":
processor = model.NewWordPiece(vocab, true) processor = model.NewWordPiece(vocab, true)
case "gpt2":
processor = model.NewBytePairEncoding(vocab)
default: default:
return nil, model.ErrUnsupportedTokenizer return nil, model.ErrUnsupportedTokenizer
} }

View File

@@ -1,12 +1,13 @@
package model package model
import ( import (
"container/heap" "cmp"
"fmt" "fmt"
"log/slog" "log/slog"
"strconv" "strconv"
"strings" "strings"
"github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
) )
@@ -94,79 +95,68 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
continue continue
} }
q := &queue{}
heap.Init(q)
runes := []rune(text) runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes { root := &merge{offset: len(runes) - 1, size: 1}
merges[r] = merge{ for i := len(runes) - 2; i >= 0; i-- {
p: r - 1, m := &merge{offset: i, size: 1, next: root}
n: r + 1, root.prev = m
runes: []rune{runes[r]}, root = m
}
} }
pairwise := func(a, b int) *candidate { pairwise := func(a, b *merge) *pair[float32] {
if a < 0 || b >= len(runes) { if a != nil && b != nil {
return nil aa := string(runes[a.offset : a.offset+a.size])
} bb := string(runes[b.offset : b.offset+b.size])
if id := spm.vocab.Encode(aa + bb); id >= 0 {
left, right := string(merges[a].runes), string(merges[b].runes) return &pair[float32]{a: a, b: b, rank: spm.vocab.Scores[id]}
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
} }
} }
return nil return nil
} }
for i := range len(runes) - 1 { pairs := binaryheap.NewWith(func(i, j *pair[float32]) int { return cmp.Compare(i.rank, j.rank) })
if pair := pairwise(i, i+1); pair != nil { for m := root; m != nil; m = m.next {
heap.Push(q, pair) if pair := pairwise(m, m.next); pair != nil {
pairs.Push(pair)
} }
} }
for q.Len() > 0 { for !pairs.Empty() {
pair := heap.Pop(q).(*candidate) p, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b] a := string(runes[p.a.offset : p.a.offset+p.a.size])
b := string(runes[p.b.offset : p.b.offset+p.b.size])
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size { if id := spm.vocab.Encode(a + b); a == "" || b == "" || id < 0 || spm.vocab.Scores[id] != p.rank {
continue continue
} }
merges[pair.a].runes = append(left.runes, right.runes...) p.a.size += p.b.size
merges[pair.b].runes = nil p.b.size = 0
merges[pair.a].n = right.n
if right.n < len(merges) { p.a.next = p.b.next
merges[right.n].p = pair.a if p.b.next != nil {
p.b.next.prev = p.a
} }
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { if pair := pairwise(p.a.prev, p.a); pair != nil {
heap.Push(q, pair) pairs.Push(pair)
} }
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { if pair := pairwise(p.a, p.a.next); pair != nil {
heap.Push(q, pair) pairs.Push(pair)
} }
} }
for _, merge := range merges { for m := root; m != nil; m = m.next {
if token := string(merge.runes); token != "" { if s := string(runes[m.offset : m.offset+m.size]); s != "" {
id := spm.vocab.Encode(token) if id := spm.vocab.Encode(s); id >= 0 {
if id >= 0 {
ids = append(ids, id) ids = append(ids, id)
continue continue
} }
// Fallback to byte tokenization
var result []int32 var result []int32
for _, b := range []byte(token) { for _, b := range []byte(s) {
byteToken := fmt.Sprintf("<0x%02X>", b) byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken) unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 { if unknownID >= 0 {
@@ -189,35 +179,6 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil return ids, nil
} }
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) { func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder var sb strings.Builder
for _, id := range ids { for _, id := range ids {

View File

@@ -2394,4 +2394,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
} }
return msgs return msgs
} }