Compare commits


2 Commits

Patrick Devine
97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package, which includes:
  * New BPE and SentencePiece tokenizers
  * Removal of the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models, which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00
natl-set
458dd1b9d9 mlx: try loading library via rpath before searching directories (#14322)
The existing code manually searches directories for libmlxc.* and passes
full paths to dlopen, bypassing the binary's rpath. This means MLX
libraries installed via package managers (e.g., Homebrew) aren't found
even when rpath is correctly set at link time.

This change tries loading via rpath first (using just the library name)
before falling back to the existing directory search. This follows standard
Unix/macOS conventions and works with any installation that sets rpath.

Fixes library loading on macOS with Homebrew-installed mlx-c without
requiring the OLLAMA_LIBRARY_PATH environment variable.

Co-authored-by: Natl <nat@MacBook-Pro.local>
2026-02-19 10:55:02 -08:00
19 changed files with 1836 additions and 16 deletions


@@ -55,6 +55,30 @@ func tryLoadFromDir(dir string) bool {
return false
}
// tryLoadByName attempts to load the library using just its name,
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
// Returns true if the library was successfully loaded.
func tryLoadByName() bool {
libraryName := "libmlxc.dylib"
if runtime.GOOS == "linux" {
libraryName = "libmlxc.so"
}
cPath := C.CString(libraryName)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
return false
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
C.mlx_dynamic_unload(&handle)
return false
}
return true
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -73,6 +97,11 @@ func init() {
}
}
// Try loading via rpath/standard library search
if tryLoadByName() {
return
}
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {
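For reference, here is a standalone cgo sketch (not part of the diff) of the load order this change introduces: pass just the library name to dlopen so the dynamic linker consults rpath, (DY)LD_LIBRARY_PATH, and the standard locations, and only fall back to explicit directories when that fails. The helper name and fallback message are illustrative; the real code goes through mlx_dynamic_load rather than calling dlopen directly.

package main

/*
#cgo linux LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdlib.h>
*/
import "C"

import (
    "fmt"
    "runtime"
    "unsafe"
)

// dlopenByName passes a bare file name to dlopen, letting the dynamic
// linker search rpath, (DY)LD_LIBRARY_PATH, and the default locations.
func dlopenByName(name string) bool {
    cName := C.CString(name)
    defer C.free(unsafe.Pointer(cName))
    return C.dlopen(cName, C.RTLD_NOW) != nil
}

func main() {
    name := "libmlxc.dylib"
    if runtime.GOOS == "linux" {
        name = "libmlxc.so"
    }
    if dlopenByName(name) {
        fmt.Println("loaded via rpath/standard search")
        return
    }
    fmt.Println("not found by name; fall back to explicit directory search")
}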


@@ -8,10 +8,10 @@ import (
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.


@@ -7,7 +7,6 @@ import (
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
@@ -126,13 +125,5 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
return flushValidUTF8Prefix(b)
}


@@ -12,12 +12,12 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
type Request struct {


@@ -0,0 +1,47 @@
package mlxrunner
import (
"bytes"
"unicode/utf8"
)
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
// currently buffered, leaving any incomplete trailing bytes in place.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
data := b.Bytes()
if len(data) == 0 {
return ""
}
prefix := validUTF8PrefixLen(data)
if prefix == 0 {
return ""
}
text := string(data[:prefix])
b.Next(prefix)
return text
}
func validUTF8PrefixLen(data []byte) int {
i := 0
prefix := 0
for i < len(data) {
r, size := utf8.DecodeRune(data[i:])
if r == utf8.RuneError && size == 1 {
if !utf8.FullRune(data[i:]) {
break
}
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
i++
prefix = i
continue
}
i += size
prefix = i
}
return prefix
}
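As a usage illustration (not part of the diff), a hypothetical example test in the same package shows the streaming pattern this helper enables: token bytes are appended as they arrive and output is emitted only once a complete rune is buffered. A build tag matching the package's other files may be needed.

package mlxrunner

import (
    "bytes"
    "fmt"
)

// Example_streamingFlush: the bytes of "こ" (0xE3 0x81 0x93) arrive split
// across two decode steps; nothing is printed until the rune is complete.
func Example_streamingFlush() {
    var buf bytes.Buffer
    for _, tokenBytes := range [][]byte{{0xE3, 0x81}, {0x93}} {
        buf.Write(tokenBytes)
        if out := flushValidUTF8Prefix(&buf); out != "" {
            fmt.Print(out)
        }
    }
    // Output: こ
}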


@@ -0,0 +1,46 @@
package mlxrunner
import (
"bytes"
"testing"
)
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
var b bytes.Buffer
b.Write([]byte{0xE3, 0x81})
if got := flushValidUTF8Prefix(&b); got != "" {
t.Fatalf("first flush = %q, want empty", got)
}
b.Write([]byte{0x93, 0xE3})
if got := flushValidUTF8Prefix(&b); got != "こ" {
t.Fatalf("second flush = %q, want %q", got, "こ")
}
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
}
b.Write([]byte{0x82, 0x93})
if got := flushValidUTF8Prefix(&b); got != "ん" {
t.Fatalf("third flush = %q, want %q", got, "ん")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after third flush: %d", b.Len())
}
}
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
var b bytes.Buffer
b.WriteString("hello 世界")
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
t.Fatalf("flush = %q, want %q", got, "hello 世界")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after flush: %d", b.Len())
}
}


@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {


@@ -9,12 +9,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {


@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {


@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

x/tokenizer/tokenizer.go (new file, 108 lines)

@@ -0,0 +1,108 @@
//go:build mlx
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
package tokenizer
import "regexp"
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE and SentencePiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
sortedSpecialTokens []string // Special tokens sorted by length, longest first
typ TokenizerType // Algorithm type
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}
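To make the byteToRune table concrete, here is a standalone sketch (not part of the diff) that applies the same mapping rules and prints a few representative bytes: printable ASCII is unchanged, while space, newline, the 0x7F-0xA0 range, and the soft hyphen are shifted into distinct printable runes so every byte has a visible token form.

package main

import "fmt"

// gpt2ByteToRune mirrors the switch in the package's init() above.
func gpt2ByteToRune(b byte) rune {
    r := rune(b)
    switch {
    case r == 0x00AD: // soft hyphen
        r = 0x0143
    case r <= 0x0020: // control characters and space
        r += 0x0100
    case r >= 0x007F && r <= 0x00A0: // DEL and the C1 range
        r += 0x00A2
    }
    return r
}

func main() {
    for _, b := range []byte{'A', ' ', '\n', 0xAD} {
        fmt.Printf("0x%02X -> %q\n", b, gpt2ByteToRune(b))
    }
    // 0x41 -> 'A', 0x20 -> 'Ġ', 0x0A -> 'Ċ', 0xAD -> 'Ń'
}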


@@ -0,0 +1,251 @@
//go:build mlx
package tokenizer
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var (
benchmarkSinkIDs []int32
benchmarkSinkStr string
benchmarkSinkTok *Tokenizer
)
const benchmarkWordPieceJSON = `{
"model": {
"type": "WordPiece",
"vocab": {
"[UNK]": 0,
"hello": 1,
"##world": 2,
"##ly": 3,
"##hello": 4
}
},
"added_tokens": []
}`
const benchmarkSentencePieceJSON = `{
"model": {
"type": "BPE",
"vocab": {
"\u2581": 0,
"h": 1,
"e": 2,
"l": 3,
"o": 4,
"w": 5,
"r": 6,
"d": 7,
"<0x0A>": 8
},
"merges": []
},
"decoder": {
"type": "Sequence",
"decoders": [
{
"type": "Replace",
"pattern": {
"String": "\u2581"
}
}
]
},
"added_tokens": []
}`
func benchmarkMiniLlamaPath(tb testing.TB) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve benchmark file path")
}
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
}
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
tb.Helper()
data := benchmarkLoadMiniLlamaBytes(tb)
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
}
return tok
}
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
tb.Helper()
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
if err != nil {
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
}
return data
}
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
tb.Helper()
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
}
return tok
}
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "short", text: "Hello, world!"},
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
}
for _, input := range inputs {
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(input.text, false)
}
})
}
}
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
}
for _, input := range inputs {
ids := tok.Encode(input.text, false)
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
})
}
}
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
data := benchmarkLoadMiniLlamaBytes(b)
config := &TokenizerConfig{
TokenizerConfigJSON: []byte(`{
"bos_token": {"content": "<|begin_of_text|>"},
"eos_token": {"content": "<|end_of_text|>"},
"add_bos_token": true
}`),
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
}
b.Run("without_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytes(data)
if err != nil {
b.Fatalf("LoadFromBytes failed: %v", err)
}
benchmarkSinkTok = tok
}
})
b.Run("with_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytesWithConfig(data, config)
if err != nil {
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
}
benchmarkSinkTok = tok
}
})
}
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}


@@ -0,0 +1,175 @@
//go:build mlx
package tokenizer
import "container/heap"
type bpeMergeNode struct {
prev int
next int
token string
}
type bpePair struct {
left int
right int
rank int
value string
}
type bpePairHeap []*bpePair
func (h bpePairHeap) Len() int { return len(h) }
func (h bpePairHeap) Less(i, j int) bool {
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
}
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *bpePairHeap) Push(x any) {
*h = append(*h, x.(*bpePair))
}
func (h *bpePairHeap) Pop() any {
old := *h
n := len(old)
item := old[n-1]
*h = old[:n-1]
return item
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
runes := []rune(encoded)
if len(runes) == 0 {
return ids
}
nodes := make([]bpeMergeNode, len(runes))
for i := range runes {
nodes[i] = bpeMergeNode{
prev: i - 1,
next: i + 1,
token: string(runes[i]),
}
}
pairwise := func(left, right int) *bpePair {
if left < 0 || right >= len(nodes) {
return nil
}
if nodes[left].token == "" || nodes[right].token == "" {
return nil
}
leftToken, rightToken := nodes[left].token, nodes[right].token
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
if !ok {
return nil
}
value := leftToken + rightToken
if _, ok := t.vocab.Reverse[value]; !ok {
return nil
}
return &bpePair{
left: left,
right: right,
rank: rank,
value: value,
}
}
pairs := bpePairHeap{}
heap.Init(&pairs)
for i := 0; i < len(runes)-1; i++ {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(&pairs, pair)
}
}
for pairs.Len() > 0 {
pair := heap.Pop(&pairs).(*bpePair)
left, right := nodes[pair.left], nodes[pair.right]
if left.token == "" || right.token == "" {
continue
}
if left.next != pair.right || right.prev != pair.left {
continue
}
if left.token+right.token != pair.value {
continue
}
nodes[pair.left].token = pair.value
nodes[pair.right].token = ""
nodes[pair.left].next = right.next
if right.next < len(nodes) {
nodes[right.next].prev = pair.left
}
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
heap.Push(&pairs, pair)
}
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
heap.Push(&pairs, pair)
}
}
for _, node := range nodes {
if node.token == "" {
continue
}
if id, ok := t.vocab.Reverse[node.token]; ok {
ids = append(ids, id)
continue
}
ids = t.appendByteFallback(ids, node.token)
}
return ids
}
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
if t.typ == TokenizerBPE {
for _, r := range token {
if b, ok := decodeByteLevelRune(r); ok {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
return ids
}
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
for _, b := range []byte(token) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
return ids
}
func decodeByteLevelRune(r rune) (byte, bool) {
switch {
case r >= 0x00 && r <= 0xFF:
return byte(r), true
case r == 0x0100:
return 0x00, true
case r == 0x0143:
return 0x00ad, true
case r > 0x0100 && r <= 0x0120:
return byte(r - 0x0100), true
case r > 0x0120 && r <= 0x0142:
return byte(r - 0x00a2), true
default:
return 0, false
}
}
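The heap-and-linked-list implementation above merges the lowest-rank adjacent pair and then only re-examines the neighbors of the merge. For reference, here is a deliberately naive, self-contained sketch (not part of the diff) of the same greedy rule: scan all pairs, merge the lowest rank, repeat. The vocabulary and ranks are made up for illustration.

package main

import "fmt"

// mergeBPE repeatedly merges the adjacent pair with the lowest rank until
// no mergeable pair remains. O(n^2) per scan, for illustration only.
func mergeBPE(word []string, ranks map[string]int) []string {
    for {
        best, bestRank := -1, int(^uint(0)>>1)
        for i := 0; i+1 < len(word); i++ {
            if r, ok := ranks[word[i]+" "+word[i+1]]; ok && r < bestRank {
                best, bestRank = i, r
            }
        }
        if best < 0 {
            return word
        }
        merged := word[best] + word[best+1]
        word = append(word[:best+1], word[best+2:]...)
        word[best] = merged
    }
}

func main() {
    ranks := map[string]int{"h e": 0, "l l": 1, "he ll": 2, "hell o": 3}
    fmt.Println(mergeBPE([]string{"h", "e", "l", "l", "o"}, ranks))
    // [h e l l o] -> [he l l o] -> [he ll o] -> [hell o] -> [hello]
}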


@@ -0,0 +1,137 @@
//go:build mlx
package tokenizer
import (
"runtime"
"strings"
"testing"
)
func equalIDs(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestEncodeRoundtripMiniLlama(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
inputs := []string{
"",
"hello",
"hello world",
" hello world ",
"don't we'll they're",
"1234567890",
"こんにちは世界",
"Hello 世界",
"func main() {}",
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
}
for _, input := range inputs {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
}
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
// Simulate construction outside loader path where cache is not set.
tok.sortedSpecialTokens = nil
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
prev := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(prev)
runtime.GOMAXPROCS(1)
seq := tok.Encode(input, false)
if prev < 2 {
runtime.GOMAXPROCS(2)
} else {
runtime.GOMAXPROCS(prev)
}
par := tok.Encode(input, false)
if !equalIDs(seq, par) {
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
}
}


@@ -0,0 +1,56 @@
//go:build mlx
package tokenizer
import (
"strconv"
"strings"
)
// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// Mirror GGML tokenizer behavior for NULL byte.
// 0x00 is omitted during decode.
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}
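A small standalone sketch (not part of the diff) of why the BPE branch writes raw bytes rather than UTF-8 runes: a multibyte character is stored as one byte-level-encoded rune per original byte, and only concatenating the decoded bytes restores it.

package main

import "fmt"

func main() {
    // "世" is 0xE4 0xB8 0x96 in UTF-8. Byte-level encoding maps
    // 0xE4 -> U+00E4, 0xB8 -> U+00B8, and 0x96 -> U+0138 (0x96 + 0xA2).
    encoded := []rune{0x00E4, 0x00B8, 0x0138}
    var out []byte
    for _, r := range encoded {
        switch { // undo the byte-level shift (only the ranges this example needs)
        case r > 0x0120 && r <= 0x0142:
            r -= 0x00A2
        case r > 0x0100 && r <= 0x0120:
            r -= 0x0100
        }
        out = append(out, byte(r)) // write the byte, not the UTF-8 rune
    }
    fmt.Println(string(out)) // 世
}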


@@ -0,0 +1,289 @@
//go:build mlx
package tokenizer
import (
"runtime"
"sort"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
const (
encodeParallelMinInputBytes = 4 * 1024
encodeParallelMinChunksPerWorker = 8
)
type tokenMatch struct {
start int
end int
}
type encodeChunk struct {
text string
isSpecial bool
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
tokens := t.sortedSpecialTokens
if len(tokens) == 0 {
// Fallback for tokenizers constructed outside the loaders.
tokens = make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
}
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
m := part[curr.start:curr.end]
nextText := part[next.start:next.end]
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
return
}
firstRune, _ := utf8.DecodeRuneInString(nextText)
if !unicode.IsLetter(firstRune) {
return
}
lastSpaceStart := curr.end
for j := curr.end; j > curr.start; {
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > curr.start {
curr.end = lastSpaceStart
next.start = lastSpaceStart
} else {
next.start = curr.start
curr.end = curr.start
}
}
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
if _, ok := t.specialTokens[part]; ok {
fn(encodeChunk{text: part, isSpecial: true})
return
}
if t.pretokenizer == nil {
fn(encodeChunk{text: part, isSpecial: false})
return
}
re := t.pretokenizer
offset := 0
loc := re.FindStringIndex(part[offset:])
if loc == nil {
return
}
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
for {
loc = re.FindStringIndex(part[offset:])
if loc == nil {
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
return
}
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
adjustWhitespaceBoundary(part, &curr, &next)
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
curr = next
}
}
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
return append(ids, id)
}
return ids
}
return t.encodeChunkInto(c.text, ids)
}
// Encode tokenizes text to token IDs.
// Parallel encoding is used only for very large inputs with enough chunks per worker.
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Fast path: encode sequentially without materializing chunk slices.
if len(s) < encodeParallelMinInputBytes {
var ids []int32
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
ids = t.appendEncodedChunk(ids, c)
})
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// For large inputs collect chunks to enable parallel processing.
var allChunks []encodeChunk
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
allChunks = append(allChunks, c)
})
}
// Encode chunks. Use the parallel path only when the chunk count is
// large enough to amortize goroutine/synchronization overhead.
useParallel := true
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers > len(allChunks) {
numWorkers = len(allChunks)
}
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
useParallel = false
}
var ids []int32
if !useParallel {
for _, c := range allChunks {
ids = t.appendEncodedChunk(ids, c)
}
} else {
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []encodeChunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
r = t.appendEncodedChunk(r, c)
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}
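A standalone sketch (not part of the diff) of the whitespace-boundary issue that adjustWhitespaceBoundary corrects. The pattern below is a simplified stand-in for the real pretokenizer: RE2 has no lookahead, so a plain \s+ keeps the trailing space inside the whitespace run, whereas the HuggingFace lookahead pattern leaves it attached to the following word.

package main

import (
    "fmt"
    "regexp"
)

func main() {
    // Simplified stand-in for the pretokenizer: words or whitespace runs.
    re := regexp.MustCompile(`\p{L}+|\s+`)
    fmt.Printf("%q\n", re.FindAllString("a  b", -1)) // ["a" "  " "b"]
    // With the original \s+(?!\S) lookahead, HuggingFace produces
    // ["a" " " " b"]; adjustWhitespaceBoundary shifts the trailing space
    // from the whitespace chunk onto the following word to match that.
}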


@@ -0,0 +1,207 @@
//go:build mlx
package tokenizer
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func llama32GGMLFixturePath(tb testing.TB, file string) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve test file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
}
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
tb.Helper()
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
if err != nil {
tb.Fatalf("failed to open encoder.json: %v", err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
tb.Fatalf("failed to decode encoder.json: %v", err)
}
type addedToken struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
}
var addedTokens []addedToken
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
id := int32(len(vocab))
vocab[token] = id
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
}
}
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
if err != nil {
tb.Fatalf("failed to open vocab.bpe: %v", err)
}
defer mf.Close()
var merges []string
scanner := bufio.NewScanner(mf)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue
}
line = strings.TrimSpace(line)
if line != "" {
merges = append(merges, line)
}
}
if err := scanner.Err(); err != nil {
tb.Fatalf("failed to read vocab.bpe: %v", err)
}
payload := struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
AddedTokens []addedToken `json:"added_tokens"`
}{}
payload.Model.Type = "BPE"
payload.Model.Vocab = vocab
payload.Model.Merges = merges
payload.PreTokenizer.Type = "Sequence"
payload.PreTokenizer.Pretokenizers = []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}{
{
Type: "Split",
Pattern: struct {
Regex string `json:"Regex"`
}{
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
},
},
}
payload.AddedTokens = addedTokens
data, err := json.Marshal(payload)
if err != nil {
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
}
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
}
return tok
}
func TestGGMLLlamaKnownEncodings(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[string][]int32{
"hello world": {15339, 1917},
"hello <|end_of_text|>": {15339, 220, 128001},
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for input, want := range cases {
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[int][]int32{
1: {15},
2: {410},
3: {931},
4: {931, 15},
5: {931, 410},
6: {931, 931},
7: {931, 931, 15},
8: {931, 931, 410},
9: {931, 931, 931},
10: {931, 931, 931, 15},
11: {931, 931, 931, 410},
12: {931, 931, 931, 931},
13: {931, 931, 931, 931, 15},
14: {931, 931, 931, 931, 410},
15: {931, 931, 931, 931, 931},
16: {931, 931, 931, 931, 931, 15},
17: {931, 931, 931, 931, 931, 410},
}
for n, want := range cases {
input := strings.Repeat("0", n)
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, input := range cases {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
ids := tok.Encode(string(rune(0x00)), false)
got := tok.Decode(ids)
if got != "" {
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
}
}


@@ -0,0 +1,458 @@
//go:build mlx
package tokenizer
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data)
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data)
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE"
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Both SentencePiece-style and GPT-2-style tokenizer.json files declare model type "BPE"; reject anything else
if raw.Model.Type != "BPE" {
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
var mergesStrings []string
if raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
if len(pair) != 2 {
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
}
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(raw.Model.Vocab)),
Reverse: raw.Model.Vocab,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// Determine tokenizer type
switch {
case detectSentencePiece(raw.Decoder):
t.typ = TokenizerSentencePiece
default:
t.typ = TokenizerBPE
}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
cacheSortedSpecialTokens(t)
return t, nil
}
func cacheSortedSpecialTokens(t *Tokenizer) {
if len(t.specialTokens) == 0 {
t.sortedSpecialTokens = nil
return
}
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
t.sortedSpecialTokens = tokens
}
type specialTokenConfigData struct {
tokenizerConfigJSON []byte
generationConfigJSON []byte
specialTokensMapJSON []byte
configJSON []byte
}
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.generationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.tokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.specialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead, which RE2 doesn't support
// - (?i:...) inline case-insensitive contraction groups, which we expand to explicit alternations
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries afterwards in adjustWhitespaceBoundary().
// With the lookahead, "a  b" (two spaces) splits into ["a", " ", " b"] (the trailing space is prepended to the word).
// A simple \s+ gives ["a", "  ", "b"], so we post-process to match the HuggingFace (Python) behavior.
func rewritePatternForRE2(pattern string) string {
// Replace lookahead pattern with simple \s+ - boundaries are fixed afterwards in adjustWhitespaceBoundary()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
applySpecialTokenConfig(t, specialTokenConfigData{
tokenizerConfigJSON: config.TokenizerConfigJSON,
generationConfigJSON: config.GenerationConfigJSON,
specialTokensMapJSON: config.SpecialTokensMapJSON,
configJSON: config.ConfigJSON,
})
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}
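Finally, a hypothetical end-to-end usage sketch of the new package, based only on the exported API in this diff (LoadFromBytesWithConfig, Encode, Decode). The file paths are illustrative, and because the package carries a //go:build mlx constraint this only compiles with the mlx build tag.

package main

import (
    "fmt"
    "os"

    "github.com/ollama/ollama/x/tokenizer"
)

func main() {
    data, err := os.ReadFile("tokenizer.json")
    if err != nil {
        panic(err)
    }
    // The companion config is optional; it supplies BOS/EOS/PAD when present.
    tokCfg, _ := os.ReadFile("tokenizer_config.json")
    tok, err := tokenizer.LoadFromBytesWithConfig(data, &tokenizer.TokenizerConfig{
        TokenizerConfigJSON: tokCfg,
    })
    if err != nil {
        panic(err)
    }
    ids := tok.Encode("Hello, world!", true) // true prepends BOS when the vocab defines one
    fmt.Println(ids)
    fmt.Println(tok.Decode(ids))
}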


@@ -0,0 +1,26 @@
//go:build mlx
package tokenizer
import (
"strings"
"testing"
)
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
data := []byte(`{
"model": {
"type": "WordPiece",
"vocab": {"[UNK]": 0, "hello": 1}
},
"added_tokens": []
}`)
_, err := LoadFromBytes(data)
if err == nil {
t.Fatal("expected WordPiece load to fail")
}
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
t.Fatalf("unexpected error: %v", err)
}
}