Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
f52671ecc6 refactor: bpe and spm tokenizer merges
- merge candidates and pairs which are essentially the same other than
  the type for rank/score
- use binaryheap in sentencepiece instead of implementing a custom structure
- update merging algorithm so it uses about 15% fewer allocations
2025-12-17 13:36:32 -08:00
Michael Yang
05711b77da re-enable new engine granite embedding 2025-12-17 13:00:51 -08:00
4 changed files with 79 additions and 136 deletions

View File

@@ -7,7 +7,7 @@ import (
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
@@ -84,16 +84,15 @@ type fragment struct {
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
// pair is a pair of adjacent merges and its rank.
//
// T is int when the rank comes from the BPE vocabulary's Merge lookup and
// float32 when it comes from the SentencePiece vocabulary's Scores table.
type pair[T int | float32] struct {
// a is the left node and b the node that follows it in the merge list.
a, b *merge
// rank is the key the merge heaps compare with cmp.Compare.
rank T
}
type merge struct {
p, n int
runes []rune
offset, size int
prev, next *merge
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
@@ -156,80 +155,61 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
root := &merge{offset: len(runes) - 1, size: 1}
for i := len(runes) - 2; i >= 0; i-- {
m := &merge{offset: i, size: 1, next: root}
root.prev = m
root = m
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
pairwise := func(a, b *merge) *pair[int] {
if a != nil && b != nil {
aa := string(runes[a.offset : a.offset+a.size])
bb := string(runes[b.offset : b.offset+b.size])
if rank := bpe.vocab.Merge(aa, bb); rank >= 0 {
return &pair[int]{a: a, b: b, rank: rank}
}
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
return nil
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs := binaryheap.NewWith(func(i, j *pair[int]) int { return cmp.Compare(i.rank, j.rank) })
for m := root; m != nil; m = m.next {
if pair := pairwise(m, m.next); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
p, _ := pairs.Pop()
a := string(runes[p.a.offset : p.a.offset+p.a.size])
b := string(runes[p.b.offset : p.b.offset+p.b.size])
if a == "" || b == "" || bpe.vocab.Merge(a, b) != p.rank {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
p.a.size += p.b.size
p.b.size = 0
p.a.next = p.b.next
if p.b.next != nil {
p.b.next.prev = p.a
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
if pair := pairwise(p.a.prev, p.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
if pair := pairwise(p.a, p.a.next); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
for m := root; m != nil; m = m.next {
if id := bpe.vocab.Encode(string(runes[m.offset : m.offset+m.size])); id >= 0 {
ids = append(ids, id)
}
}
}

View File

@@ -133,6 +133,7 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
@@ -157,6 +158,8 @@ func New(c fs.Config) (model.Model, error) {
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
case "gpt2":
processor = model.NewBytePairEncoding(vocab)
default:
return nil, model.ErrUnsupportedTokenizer
}

View File

@@ -1,12 +1,13 @@
package model
import (
"container/heap"
"cmp"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
@@ -94,79 +95,68 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
root := &merge{offset: len(runes) - 1, size: 1}
for i := len(runes) - 2; i >= 0; i-- {
m := &merge{offset: i, size: 1, next: root}
root.prev = m
root = m
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
pairwise := func(a, b *merge) *pair[float32] {
if a != nil && b != nil {
aa := string(runes[a.offset : a.offset+a.size])
bb := string(runes[b.offset : b.offset+b.size])
if id := spm.vocab.Encode(aa + bb); id >= 0 {
return &pair[float32]{a: a, b: b, rank: spm.vocab.Scores[id]}
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
pairs := binaryheap.NewWith(func(i, j *pair[float32]) int { return cmp.Compare(i.rank, j.rank) })
for m := root; m != nil; m = m.next {
if pair := pairwise(m, m.next); pair != nil {
pairs.Push(pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
for !pairs.Empty() {
p, _ := pairs.Pop()
a := string(runes[p.a.offset : p.a.offset+p.a.size])
b := string(runes[p.b.offset : p.b.offset+p.b.size])
if id := spm.vocab.Encode(a + b); a == "" || b == "" || id < 0 || spm.vocab.Scores[id] != p.rank {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
p.a.size += p.b.size
p.b.size = 0
p.a.next = p.b.next
if p.b.next != nil {
p.b.next.prev = p.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
if pair := pairwise(p.a.prev, p.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
if pair := pairwise(p.a, p.a.next); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
for m := root; m != nil; m = m.next {
if s := string(runes[m.offset : m.offset+m.size]); s != "" {
if id := spm.vocab.Encode(s); id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
for _, b := range []byte(s) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
@@ -189,35 +179,6 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {

View File

@@ -2394,4 +2394,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}