mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
2 Commits
parth/decr
...
mxyng/toke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f52671ecc6 | ||
|
|
05711b77da |
@@ -7,7 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
"github.com/dlclark/regexp2"
|
||||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
"github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,16 +84,15 @@ type fragment struct {
|
|||||||
ids []int32
|
ids []int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// pair is a pair of runes and its rank
|
// pair is a pair of merges and its rank
|
||||||
type pair struct {
|
type pair[T int | float32] struct {
|
||||||
a, b int
|
a, b *merge
|
||||||
rank int
|
rank T
|
||||||
value string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type merge struct {
|
type merge struct {
|
||||||
p, n int
|
offset, size int
|
||||||
runes []rune
|
prev, next *merge
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
@@ -156,84 +155,65 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
runes := []rune(sb.String())
|
runes := []rune(sb.String())
|
||||||
merges := make([]merge, len(runes))
|
|
||||||
for r := range runes {
|
root := &merge{offset: len(runes) - 1, size: 1}
|
||||||
merges[r] = merge{
|
for i := len(runes) - 2; i >= 0; i-- {
|
||||||
p: r - 1,
|
m := &merge{offset: i, size: 1, next: root}
|
||||||
n: r + 1,
|
root.prev = m
|
||||||
runes: []rune{runes[r]},
|
root = m
|
||||||
|
}
|
||||||
|
|
||||||
|
pairwise := func(a, b *merge) *pair[int] {
|
||||||
|
if a != nil && b != nil {
|
||||||
|
aa := string(runes[a.offset : a.offset+a.size])
|
||||||
|
bb := string(runes[b.offset : b.offset+b.size])
|
||||||
|
if rank := bpe.vocab.Merge(aa, bb); rank >= 0 {
|
||||||
|
return &pair[int]{a: a, b: b, rank: rank}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pairwise := func(a, b int) *pair {
|
|
||||||
if a < 0 || b >= len(runes) {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
pairs := binaryheap.NewWith(func(i, j *pair[int]) int { return cmp.Compare(i.rank, j.rank) })
|
||||||
rank := bpe.vocab.Merge(left, right)
|
for m := root; m != nil; m = m.next {
|
||||||
if rank < 0 {
|
if pair := pairwise(m, m.next); pair != nil {
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &pair{
|
|
||||||
a: a,
|
|
||||||
b: b,
|
|
||||||
rank: rank,
|
|
||||||
value: left + right,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pairs := heap.NewWith(func(i, j *pair) int {
|
|
||||||
return cmp.Compare(i.rank, j.rank)
|
|
||||||
})
|
|
||||||
|
|
||||||
for i := range len(runes) - 1 {
|
|
||||||
if pair := pairwise(i, i+1); pair != nil {
|
|
||||||
pairs.Push(pair)
|
pairs.Push(pair)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for !pairs.Empty() {
|
for !pairs.Empty() {
|
||||||
pair, _ := pairs.Pop()
|
p, _ := pairs.Pop()
|
||||||
|
a := string(runes[p.a.offset : p.a.offset+p.a.size])
|
||||||
left, right := merges[pair.a], merges[pair.b]
|
b := string(runes[p.b.offset : p.b.offset+p.b.size])
|
||||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
if a == "" || b == "" || bpe.vocab.Merge(a, b) != p.rank {
|
||||||
string(left.runes)+string(right.runes) != pair.value {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
p.a.size += p.b.size
|
||||||
continue
|
p.b.size = 0
|
||||||
|
|
||||||
|
p.a.next = p.b.next
|
||||||
|
if p.b.next != nil {
|
||||||
|
p.b.next.prev = p.a
|
||||||
}
|
}
|
||||||
|
|
||||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
if pair := pairwise(p.a.prev, p.a); pair != nil {
|
||||||
merges[pair.b].runes = nil
|
|
||||||
|
|
||||||
merges[pair.a].n = right.n
|
|
||||||
if right.n < len(merges) {
|
|
||||||
merges[right.n].p = pair.a
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
|
||||||
pairs.Push(pair)
|
pairs.Push(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
if pair := pairwise(p.a, p.a.next); pair != nil {
|
||||||
pairs.Push(pair)
|
pairs.Push(pair)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, merge := range merges {
|
for m := root; m != nil; m = m.next {
|
||||||
if len(merge.runes) > 0 {
|
if id := bpe.vocab.Encode(string(runes[m.offset : m.offset+m.size])); id >= 0 {
|
||||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
|
||||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if addSpecial {
|
if addSpecial {
|
||||||
ids = bpe.vocab.addSpecials(ids)
|
ids = bpe.vocab.addSpecials(ids)
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
BOS: []int32{
|
BOS: []int32{
|
||||||
int32(cmp.Or(
|
int32(cmp.Or(
|
||||||
@@ -157,6 +158,8 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
switch c.String("tokenizer.ggml.model", "bert") {
|
switch c.String("tokenizer.ggml.model", "bert") {
|
||||||
case "bert":
|
case "bert":
|
||||||
processor = model.NewWordPiece(vocab, true)
|
processor = model.NewWordPiece(vocab, true)
|
||||||
|
case "gpt2":
|
||||||
|
processor = model.NewBytePairEncoding(vocab)
|
||||||
default:
|
default:
|
||||||
return nil, model.ErrUnsupportedTokenizer
|
return nil, model.ErrUnsupportedTokenizer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/heap"
|
"cmp"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -94,79 +95,68 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
q := &queue{}
|
|
||||||
heap.Init(q)
|
|
||||||
|
|
||||||
runes := []rune(text)
|
runes := []rune(text)
|
||||||
merges := make([]merge, len(runes))
|
|
||||||
for r := range runes {
|
root := &merge{offset: len(runes) - 1, size: 1}
|
||||||
merges[r] = merge{
|
for i := len(runes) - 2; i >= 0; i-- {
|
||||||
p: r - 1,
|
m := &merge{offset: i, size: 1, next: root}
|
||||||
n: r + 1,
|
root.prev = m
|
||||||
runes: []rune{runes[r]},
|
root = m
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pairwise := func(a, b int) *candidate {
|
pairwise := func(a, b *merge) *pair[float32] {
|
||||||
if a < 0 || b >= len(runes) {
|
if a != nil && b != nil {
|
||||||
return nil
|
aa := string(runes[a.offset : a.offset+a.size])
|
||||||
}
|
bb := string(runes[b.offset : b.offset+b.size])
|
||||||
|
if id := spm.vocab.Encode(aa + bb); id >= 0 {
|
||||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
return &pair[float32]{a: a, b: b, rank: spm.vocab.Scores[id]}
|
||||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
|
||||||
return &candidate{
|
|
||||||
a: a,
|
|
||||||
b: b,
|
|
||||||
score: spm.vocab.Scores[id],
|
|
||||||
size: len(left) + len(right),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range len(runes) - 1 {
|
pairs := binaryheap.NewWith(func(i, j *pair[float32]) int { return cmp.Compare(i.rank, j.rank) })
|
||||||
if pair := pairwise(i, i+1); pair != nil {
|
for m := root; m != nil; m = m.next {
|
||||||
heap.Push(q, pair)
|
if pair := pairwise(m, m.next); pair != nil {
|
||||||
|
pairs.Push(pair)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for q.Len() > 0 {
|
for !pairs.Empty() {
|
||||||
pair := heap.Pop(q).(*candidate)
|
p, _ := pairs.Pop()
|
||||||
left, right := merges[pair.a], merges[pair.b]
|
a := string(runes[p.a.offset : p.a.offset+p.a.size])
|
||||||
|
b := string(runes[p.b.offset : p.b.offset+p.b.size])
|
||||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
if id := spm.vocab.Encode(a + b); a == "" || b == "" || id < 0 || spm.vocab.Scores[id] != p.rank {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
p.a.size += p.b.size
|
||||||
merges[pair.b].runes = nil
|
p.b.size = 0
|
||||||
merges[pair.a].n = right.n
|
|
||||||
if right.n < len(merges) {
|
p.a.next = p.b.next
|
||||||
merges[right.n].p = pair.a
|
if p.b.next != nil {
|
||||||
|
p.b.next.prev = p.a
|
||||||
}
|
}
|
||||||
|
|
||||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
if pair := pairwise(p.a.prev, p.a); pair != nil {
|
||||||
heap.Push(q, pair)
|
pairs.Push(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
if pair := pairwise(p.a, p.a.next); pair != nil {
|
||||||
heap.Push(q, pair)
|
pairs.Push(pair)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, merge := range merges {
|
for m := root; m != nil; m = m.next {
|
||||||
if token := string(merge.runes); token != "" {
|
if s := string(runes[m.offset : m.offset+m.size]); s != "" {
|
||||||
id := spm.vocab.Encode(token)
|
if id := spm.vocab.Encode(s); id >= 0 {
|
||||||
|
|
||||||
if id >= 0 {
|
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to byte tokenization
|
|
||||||
var result []int32
|
var result []int32
|
||||||
for _, b := range []byte(token) {
|
for _, b := range []byte(s) {
|
||||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||||
unknownID := spm.vocab.Encode(byteToken)
|
unknownID := spm.vocab.Encode(byteToken)
|
||||||
if unknownID >= 0 {
|
if unknownID >= 0 {
|
||||||
@@ -189,35 +179,6 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type candidate struct {
|
|
||||||
a, b int
|
|
||||||
score float32
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
type queue []*candidate
|
|
||||||
|
|
||||||
func (q queue) Len() int { return len(q) }
|
|
||||||
|
|
||||||
func (q queue) Less(i, j int) bool {
|
|
||||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
|
||||||
|
|
||||||
func (q *queue) Push(x interface{}) {
|
|
||||||
item := x.(*candidate)
|
|
||||||
*q = append(*q, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *queue) Pop() interface{} {
|
|
||||||
old := *q
|
|
||||||
n := len(old)
|
|
||||||
item := old[n-1]
|
|
||||||
*q = old[0 : n-1]
|
|
||||||
return item
|
|
||||||
}
|
|
||||||
|
|
||||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
|
|||||||
@@ -2394,4 +2394,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
|||||||
}
|
}
|
||||||
return msgs
|
return msgs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user