mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
2 Commits
implement-
...
mxyng/toke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f52671ecc6 | ||
|
|
05711b77da |
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
@@ -84,16 +84,15 @@ type fragment struct {
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
// pair is a pair of merges and its rank
|
||||
type pair[T int | float32] struct {
|
||||
a, b *merge
|
||||
rank T
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
offset, size int
|
||||
prev, next *merge
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
@@ -156,80 +155,61 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
|
||||
root := &merge{offset: len(runes) - 1, size: 1}
|
||||
for i := len(runes) - 2; i >= 0; i-- {
|
||||
m := &merge{offset: i, size: 1, next: root}
|
||||
root.prev = m
|
||||
root = m
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
pairwise := func(a, b *merge) *pair[int] {
|
||||
if a != nil && b != nil {
|
||||
aa := string(runes[a.offset : a.offset+a.size])
|
||||
bb := string(runes[b.offset : b.offset+b.size])
|
||||
if rank := bpe.vocab.Merge(aa, bb); rank >= 0 {
|
||||
return &pair[int]{a: a, b: b, rank: rank}
|
||||
}
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs := binaryheap.NewWith(func(i, j *pair[int]) int { return cmp.Compare(i.rank, j.rank) })
|
||||
for m := root; m != nil; m = m.next {
|
||||
if pair := pairwise(m, m.next); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
p, _ := pairs.Pop()
|
||||
a := string(runes[p.a.offset : p.a.offset+p.a.size])
|
||||
b := string(runes[p.b.offset : p.b.offset+p.b.size])
|
||||
if a == "" || b == "" || bpe.vocab.Merge(a, b) != p.rank {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
p.a.size += p.b.size
|
||||
p.b.size = 0
|
||||
|
||||
p.a.next = p.b.next
|
||||
if p.b.next != nil {
|
||||
p.b.next.prev = p.a
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
if pair := pairwise(p.a.prev, p.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
if pair := pairwise(p.a, p.a.next); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
for m := root; m != nil; m = m.next {
|
||||
if id := bpe.vocab.Encode(string(runes[m.offset : m.offset+m.size])); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
@@ -157,6 +158,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(vocab)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
@@ -94,79 +95,68 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
q := &queue{}
|
||||
heap.Init(q)
|
||||
|
||||
runes := []rune(text)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
|
||||
root := &merge{offset: len(runes) - 1, size: 1}
|
||||
for i := len(runes) - 2; i >= 0; i-- {
|
||||
m := &merge{offset: i, size: 1, next: root}
|
||||
root.prev = m
|
||||
root = m
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
size: len(left) + len(right),
|
||||
pairwise := func(a, b *merge) *pair[float32] {
|
||||
if a != nil && b != nil {
|
||||
aa := string(runes[a.offset : a.offset+a.size])
|
||||
bb := string(runes[b.offset : b.offset+b.size])
|
||||
if id := spm.vocab.Encode(aa + bb); id >= 0 {
|
||||
return &pair[float32]{a: a, b: b, rank: spm.vocab.Scores[id]}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
pairs := binaryheap.NewWith(func(i, j *pair[float32]) int { return cmp.Compare(i.rank, j.rank) })
|
||||
for m := root; m != nil; m = m.next {
|
||||
if pair := pairwise(m, m.next); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for q.Len() > 0 {
|
||||
pair := heap.Pop(q).(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
|
||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||
for !pairs.Empty() {
|
||||
p, _ := pairs.Pop()
|
||||
a := string(runes[p.a.offset : p.a.offset+p.a.size])
|
||||
b := string(runes[p.b.offset : p.b.offset+p.b.size])
|
||||
if id := spm.vocab.Encode(a + b); a == "" || b == "" || id < 0 || spm.vocab.Scores[id] != p.rank {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
p.a.size += p.b.size
|
||||
p.b.size = 0
|
||||
|
||||
p.a.next = p.b.next
|
||||
if p.b.next != nil {
|
||||
p.b.next.prev = p.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
if pair := pairwise(p.a.prev, p.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
if pair := pairwise(p.a, p.a.next); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if token := string(merge.runes); token != "" {
|
||||
id := spm.vocab.Encode(token)
|
||||
|
||||
if id >= 0 {
|
||||
for m := root; m != nil; m = m.next {
|
||||
if s := string(runes[m.offset : m.offset+m.size]); s != "" {
|
||||
if id := spm.vocab.Encode(s); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Fallback to byte tokenization
|
||||
var result []int32
|
||||
for _, b := range []byte(token) {
|
||||
for _, b := range []byte(s) {
|
||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||
unknownID := spm.vocab.Encode(byteToken)
|
||||
if unknownID >= 0 {
|
||||
@@ -189,35 +179,6 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
a, b int
|
||||
score float32
|
||||
size int
|
||||
}
|
||||
|
||||
type queue []*candidate
|
||||
|
||||
func (q queue) Len() int { return len(q) }
|
||||
|
||||
func (q queue) Less(i, j int) bool {
|
||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||
}
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*q = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
|
||||
@@ -2394,4 +2394,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user