Compare commits

..

8 Commits

Author SHA1 Message Date
Michael Yang
fcfbb06f1b cmd: handle sigint globally
This change also updates both client.do and client.stream to return
ctx.Err(). Previously this error is skipped so canceled contexts are
silently ignored
2025-02-19 10:46:25 -08:00
Michael Yang
e8d35d0de0 cmd: fix hide cursor
hides the cursor for the entire progress rather than each render cycle
2025-02-19 09:43:44 -08:00
Michael Yang
e13e7c8d94 Merge pull request #9079 from jeremyschlatter/main
cmd: fix flickering in progress bar
2025-02-18 22:59:29 +00:00
Jeremy Schlatter
78f403ff45 address code review comments 2025-02-18 14:50:09 -08:00
Michael Yang
08a299e1d0 cmake: avoid building intel backends on linux 2025-02-18 22:17:00 +00:00
Jeremy Schlatter
f9c7ead160 cmd: eliminate flickering with synchronized output 2025-02-17 20:01:03 -08:00
Jeremy Schlatter
5930aaeb1a cmd: fix cursor flickering in progress bar
The previous commit fixed flickering in the progress bar itself. Cursor
flickering is harder to address.

Cursor flickering could be fixed by hiding the cursor altogether while
the progress bar is displayed. The downside of this is that if the
program is killed in such a way that it can't clean up its state, it
would leave the cursor invisible.

Instead, this commit introduces an output buffer. All of the escape
codes and content for a single progress update are written to a buffer,
which is then flushed to the terminal all at once. This significantly
decreases the time during which the terminal has seen the cursor-hiding
code but has not yet seen the cursor-showing code, thus minimizing (but
not 100% eliminating) cursor flickering.

For more context, see:
https://gitlab.gnome.org/GNOME/vte/-/issues/2837#note_2269501
2025-02-17 14:56:57 -08:00
Jeremy Schlatter
faf67db089 cmd: fix progress bar flickering
Previous code cleared the display before writing new content, creating a
window where the terminal could (and in some cases did) render empty lines.

Instead, we now write new content over the old content, only clearing
the trailing end of lines for cases where the new line is shorter.

Fixes #1664
2025-02-17 13:39:02 -08:00
17 changed files with 61 additions and 513 deletions

View File

@@ -24,7 +24,7 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
if((NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
set(GGML_CPU_ALL_VARIANTS ON)
endif()

View File

@@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return err
}
}
return nil
return ctx.Err()
}
const maxBufferSize = 512 * format.KiloByte
@@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
}
return nil
return ctx.Err()
}
// GenerateResponseFunc is a function that [Client.Generate] invokes every time

View File

@@ -15,13 +15,11 @@ import (
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/containerd/console"
@@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
}
return info, err
@@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
@@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
if err := client.Chat(cmd.Context(), req, fn); err != nil {
return nil, err
}
@@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
generateContext = []int{}
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
fn := func(response api.GenerateResponse) error {
@@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
latest.Summary()
}
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
return nil
}

View File

@@ -120,15 +120,6 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
}
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key

1
go.mod
View File

@@ -18,7 +18,6 @@ require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods v1.18.1
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14

2
go.sum
View File

@@ -44,8 +44,6 @@ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=

View File

@@ -434,7 +434,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}

14
main.go
View File

@@ -2,6 +2,8 @@ package main
import (
"context"
"os"
"os/signal"
"github.com/spf13/cobra"
@@ -9,5 +11,15 @@ import (
)
func main() {
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
go func() {
<-sigChan
cancel()
}()
cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
}

View File

@@ -17,7 +17,6 @@ type Config interface {
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
type Backend interface {
@@ -77,7 +76,7 @@ type Tensor interface {
Scale(ctx Context, s float64) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor

View File

@@ -596,13 +596,10 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
const (
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
ropeTypeMrope C.int = 8
ropeTypeVision C.int = 24
ropeTypeNorm C.int = iota
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
}
@@ -616,8 +613,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(ropeType),
131072, // YaRN n_ctx_train
131072, // YaRN n_ctx_train
ropeTypeNorm, // ROPE_TYPE_NORM
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor

View File

@@ -1,193 +0,0 @@
package gemma2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
attnLogitSoftcap float32
finalLogitSoftcap float32
}
type Model struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"` // is this supposed to be root means square?
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
*Options
}
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
// todo: this should be 1.0/math.Sqrt(float64(headDim)) for 27B models
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
for i, layer := range m.Layers {
cacheType := i % 2
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("gemma2", New)
}

View File

@@ -67,15 +67,14 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -100,7 +99,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, uint32(0), m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {

View File

@@ -19,15 +19,14 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -53,7 +52,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type TextMLP struct {

View File

@@ -1,7 +1,6 @@
package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -18,15 +18,6 @@ const (
SpecialEOS
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(string) ([]int32, error)
Decode([]int32) (string, error)
@@ -36,7 +27,7 @@ type TextProcessor interface {
type Vocabulary struct {
Values []string
Types []uint32
Scores []float32
Scores []uint32
Merges []string
BOS, EOS int32
@@ -84,7 +75,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL {
if v.Types[i] == 3 {
v.special = append(v.special, v.Values[i])
}
}

View File

@@ -1,220 +0,0 @@
package model
import (
"iter"
"log/slog"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:3], "scores", vocab.Scores[:3], "types", vocab.Types[:3])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return 1
}
return -1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
length: len(left + " " + right),
score: spm.vocab.Scores[id],
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
slog.Debug("missing token", "token", string(merge.runes))
}
}
}
}
}
slog.Debug("encoded", "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
length int
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

View File

@@ -1,6 +1,7 @@
package progress
import (
"bufio"
"fmt"
"io"
"sync"
@@ -13,7 +14,8 @@ type State interface {
type Progress struct {
mu sync.Mutex
w io.Writer
// buffer output to minimize flickering on all terminals
w *bufio.Writer
pos int
@@ -22,7 +24,7 @@ type Progress struct {
}
func NewProgress(w io.Writer) *Progress {
p := &Progress{w: w}
p := &Progress{w: bufio.NewWriter(w)}
go p.start()
return p
}
@@ -47,26 +49,29 @@ func (p *Progress) stop() bool {
func (p *Progress) Stop() bool {
stopped := p.stop()
if stopped {
fmt.Fprint(p.w, "\n")
fmt.Fprintln(p.w)
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
func (p *Progress) StopAndClear() bool {
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
stopped := p.stop()
if stopped {
// clear all progress lines
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K", "\033[1G")
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
@@ -81,30 +86,31 @@ func (p *Progress) render() {
p.mu.Lock()
defer p.mu.Unlock()
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
fmt.Fprint(p.w, "\033[?2026h")
defer fmt.Fprint(p.w, "\033[?2026l")
// clear already rendered progress lines
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[1G")
// render progress lines
for i, state := range p.states {
fmt.Fprint(p.w, state.String())
fmt.Fprint(p.w, state.String(), "\033[K")
if i < len(p.states)-1 {
fmt.Fprint(p.w, "\n")
}
}
p.pos = len(p.states)
p.w.Flush()
}
func (p *Progress) start() {
p.ticker = time.NewTicker(100 * time.Millisecond)
// hide cursor
fmt.Fprint(p.w, "\033[?25l")
for range p.ticker.C {
p.render()
}