mirror of
https://github.com/ollama/ollama.git
synced 2026-01-13 18:09:59 -05:00
Compare commits
7 Commits
jmorganca/
...
pdevine/bf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c75b428249 | ||
|
|
021dcf089d | ||
|
|
bf24498b1e | ||
|
|
95e271d98f | ||
|
|
364629b8d6 | ||
|
|
108fe02165 | ||
|
|
4561fff36e |
@@ -11,9 +11,10 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/ollama/ollama/types/bfloat16"
|
||||
)
|
||||
|
||||
type safetensorMetadata struct {
|
||||
|
||||
1
go.mod
1
go.mod
@@ -16,7 +16,6 @@ require (
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/google/go-cmp v0.6.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
|
||||
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
||||
@@ -312,17 +312,19 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
||||
}
|
||||
|
||||
bts := make([]byte, t.Size())
|
||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
|
||||
if err != nil {
|
||||
return err
|
||||
bts := C.malloc(C.size_t(t.Size()))
|
||||
if bts == nil {
|
||||
return errors.New("failed to allocate tensor buffer")
|
||||
}
|
||||
defer C.free(bts)
|
||||
|
||||
buf := unsafe.Slice((*byte)(bts), t.Size())
|
||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
|
||||
if err != nil || n != len(buf) {
|
||||
return errors.New("read failed")
|
||||
}
|
||||
|
||||
if n != len(bts) {
|
||||
return errors.New("short read")
|
||||
}
|
||||
|
||||
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
|
||||
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -371,7 +373,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
|
||||
C.int(len(schedBackends)),
|
||||
C.size_t(maxGraphNodes),
|
||||
true,
|
||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||
),
|
||||
input: deviceBufferTypes[input.d],
|
||||
output: deviceBufferTypes[output.d],
|
||||
|
||||
@@ -89,7 +89,7 @@ type InputCacheSlot struct {
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
|
||||
@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
wantErr bool
|
||||
expectedSlotId int
|
||||
expectedPrompt int // expected length of remaining prompt
|
||||
}{
|
||||
{
|
||||
name: "Basic cache hit - single user",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
},
|
||||
{
|
||||
name: "Basic cache hit - multi user",
|
||||
cache: InputCache{
|
||||
multiUserCache: true,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
},
|
||||
{
|
||||
name: "Exact match - leave one input",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||
},
|
||||
{
|
||||
name: "No available slots",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: true,
|
||||
expectedSlotId: -1,
|
||||
expectedPrompt: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
||||
|
||||
// Check error state
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return // Skip further checks if we expected an error
|
||||
}
|
||||
|
||||
// Verify slot ID
|
||||
if slot.Id != tt.expectedSlotId {
|
||||
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
|
||||
}
|
||||
|
||||
// Verify slot is now marked in use
|
||||
if !slot.InUse {
|
||||
t.Errorf("LoadCacheSlot() slot not marked InUse")
|
||||
}
|
||||
|
||||
// Verify remaining prompt length
|
||||
if len(remainingPrompt) != tt.expectedPrompt {
|
||||
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
|
||||
len(remainingPrompt), tt.expectedPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
params.numKeep = int32(len(inputs))
|
||||
}
|
||||
|
||||
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
|
||||
// otherwise we might truncate or split the batch against the model's wishes
|
||||
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
@@ -366,17 +369,6 @@ func (s *Server) processBatch() error {
|
||||
batchSize := s.batchSize
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have pending inputs.
|
||||
@@ -389,6 +381,20 @@ func (s *Server) processBatch() error {
|
||||
break
|
||||
}
|
||||
|
||||
// If the sum of our working set (already processed tokens, tokens we added to this
|
||||
// batch, required following tokens) exceeds the context size, then trigger a shift
|
||||
// now so we don't have to do one later when we can't break the batch.
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
||||
@@ -590,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
|
||||
@@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
// topK also sorts the tokens in descending order of logits
|
||||
tokens = topK(tokens, s.topK)
|
||||
|
||||
tokens = temperature(tokens, s.temperature)
|
||||
tokens = softmax(tokens)
|
||||
// scale and normalize the tokens in place
|
||||
temperature(tokens, s.temperature)
|
||||
softmax(tokens)
|
||||
|
||||
tokens = topP(tokens, s.topP)
|
||||
tokens = minP(tokens, s.minP)
|
||||
|
||||
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
temp = max(temp, 1e-7)
|
||||
for i := range ts {
|
||||
ts[i].value = ts[i].value / temp
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
// softmax applies normalization to the logits
|
||||
func softmax(ts []token) []token {
|
||||
func softmax(ts []token) {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
|
||||
}
|
||||
|
||||
// topP limits tokens to those with cumulative probability p
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func topP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
|
||||
for i, t := range ts {
|
||||
sum += t.value
|
||||
if sum > float32(p) {
|
||||
ts = ts[:i+1]
|
||||
return ts
|
||||
return ts[:i+1]
|
||||
}
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// minP limits tokens to those with cumulative probability p
|
||||
// minP filters tokens with probabilities >= p * max_prob
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func minP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
}
|
||||
maxProb := ts[0].value
|
||||
|
||||
maxProb := float32(math.Inf(-1))
|
||||
for _, token := range ts {
|
||||
if token.value > maxProb {
|
||||
maxProb = token.value
|
||||
threshold := maxProb * p
|
||||
|
||||
for i, t := range ts {
|
||||
if t.value < threshold {
|
||||
return ts[:i]
|
||||
}
|
||||
}
|
||||
|
||||
threshold := maxProb * float32(p)
|
||||
|
||||
// Filter tokens in-place
|
||||
validTokens := ts[:0]
|
||||
for i, token := range ts {
|
||||
if token.value >= threshold {
|
||||
validTokens = append(validTokens, ts[i])
|
||||
}
|
||||
}
|
||||
|
||||
ts = validTokens
|
||||
return ts
|
||||
}
|
||||
|
||||
@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
|
||||
|
||||
func TestTemperature(t *testing.T) {
|
||||
input := []float32{1.0, 4.0, -2.0, 0.0}
|
||||
got := temperature(toTokens(input), 0.5)
|
||||
tokens := toTokens(input)
|
||||
temperature(tokens, 0.5)
|
||||
want := []float32{2.0, 8.0, -4.0, 0.0}
|
||||
compareLogits(t, "temperature(0.5)", want, got)
|
||||
compareLogits(t, "temperature(0.5)", want, tokens)
|
||||
|
||||
got = temperature(toTokens(input), 1.0)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 1.0)
|
||||
want = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
compareLogits(t, "temperature(1)", want, got)
|
||||
compareLogits(t, "temperature(1)", want, tokens)
|
||||
|
||||
got = temperature(toTokens(input), 0.0)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 0.0)
|
||||
want = []float32{1e7, 4e7, -2e7, 0.0}
|
||||
compareLogits(t, "temperature(0)", want, got)
|
||||
compareLogits(t, "temperature(0)", want, tokens)
|
||||
}
|
||||
|
||||
func TestSoftmax(t *testing.T) {
|
||||
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := softmax(toTokens(tt.input))
|
||||
tokens := toTokens(tt.input)
|
||||
softmax(tokens)
|
||||
|
||||
if tt.expected != nil {
|
||||
compareLogits(t, tt.name, tt.expected, got)
|
||||
compareLogits(t, tt.name, tt.expected, tokens)
|
||||
return
|
||||
}
|
||||
|
||||
// Check probabilities sum to 1
|
||||
var sum float32
|
||||
for _, token := range got {
|
||||
for _, token := range tokens {
|
||||
sum += token.value
|
||||
if token.value < 0 || token.value > 1 {
|
||||
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
||||
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
|
||||
// Test k=5
|
||||
got := topK(toTokens(input), 5)
|
||||
if len(got) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
|
||||
tokens := toTokens(input)
|
||||
tokens = topK(tokens, 5)
|
||||
if len(tokens) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
|
||||
}
|
||||
// Should keep highest 3 values in descending order
|
||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
||||
compareLogits(t, "topK(3)", want, got)
|
||||
compareLogits(t, "topK(3)", want, tokens)
|
||||
|
||||
got = topK(toTokens(input), 20)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
|
||||
// Test k=-1
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
got = topK(toTokens(input), -1)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, -1)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, got)
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
// Test k=0
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
got = topK(toTokens(input), 0)
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 0)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
input = []float32{-1e7, -2e7, -3e7, -4e7}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 1)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topK should keep at least one token")
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, got)
|
||||
}
|
||||
|
||||
func TestTopP(t *testing.T) {
|
||||
@@ -153,16 +165,25 @@ func TestTopP(t *testing.T) {
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax to get probabilities
|
||||
tokens = softmax(tokens)
|
||||
softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
|
||||
// Then apply topP
|
||||
got := topP(tokens, 0.95)
|
||||
tokens = topP(tokens, 0.95)
|
||||
|
||||
// Should keep tokens until cumsum > 0.95
|
||||
if len(got) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test edge case - ensure at least one token remains
|
||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = topP(tokens, 0.0) // Very small p
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topP should keep at least one token")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,14 +192,45 @@ func TestMinP(t *testing.T) {
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax
|
||||
tokens = softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
|
||||
// Then apply minP
|
||||
got := minP(tokens, 0.2)
|
||||
tokens = minP(tokens, 1.0)
|
||||
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
input = []float32{1e-10, 1e-10, 1e-10}
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 10)
|
||||
tokens = topK(tokensCopy, 10)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topP(tokensCopy, 0.9)
|
||||
tokens = topP(tokensCopy, 0.9)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
minP(tokensCopy, 0.2)
|
||||
tokens = minP(tokensCopy, 0.2)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 200000)
|
||||
tokens = topK(tokensCopy, 200000)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
21
types/bfloat16/LICENSE
Normal file
21
types/bfloat16/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Tristan Rice
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
57
types/bfloat16/bfloat16.go
Normal file
57
types/bfloat16/bfloat16.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Vendored code from https://github.com/d4l3k/go-bfloat16
|
||||
// unsafe pointer replaced by "math"
|
||||
package bfloat16
|
||||
|
||||
import "math"
|
||||
|
||||
type BF16 uint16
|
||||
|
||||
func FromBytes(buf []byte) BF16 {
|
||||
return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
|
||||
}
|
||||
|
||||
func ToBytes(b BF16) []byte {
|
||||
return []byte{byte(b & 0xFF), byte(b >> 8)}
|
||||
}
|
||||
|
||||
func Decode(buf []byte) []BF16 {
|
||||
var out []BF16
|
||||
for i := 0; i < len(buf); i += 2 {
|
||||
out = append(out, FromBytes(buf[i:]))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Encode(f []BF16) []byte {
|
||||
var out []byte
|
||||
for _, a := range f {
|
||||
out = append(out, ToBytes(a)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func DecodeFloat32(buf []byte) []float32 {
|
||||
var out []float32
|
||||
for i := 0; i < len(buf); i += 2 {
|
||||
out = append(out, ToFloat32(FromBytes(buf[i:])))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func EncodeFloat32(f []float32) []byte {
|
||||
var out []byte
|
||||
for _, a := range f {
|
||||
out = append(out, ToBytes(FromFloat32(a))...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ToFloat32(b BF16) float32 {
|
||||
u32 := uint32(b) << 16
|
||||
return math.Float32frombits(u32)
|
||||
}
|
||||
|
||||
func FromFloat32(f float32) BF16 {
|
||||
u32 := math.Float32bits(f)
|
||||
return BF16(u32 >> 16)
|
||||
}
|
||||
53
types/bfloat16/bfloat16_test.go
Normal file
53
types/bfloat16/bfloat16_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package bfloat16
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func randomBytes(n int) []byte {
|
||||
out := make([]byte, n)
|
||||
if _, err := rand.Read(out); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
b := randomBytes(1024)
|
||||
bf16 := Decode(b)
|
||||
out := Encode(bf16)
|
||||
if !reflect.DeepEqual(b, out) {
|
||||
t.Fatalf("%+v != %+v", b, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeFloat32(t *testing.T) {
|
||||
b := randomBytes(1024)
|
||||
bf16 := DecodeFloat32(b)
|
||||
out := EncodeFloat32(bf16)
|
||||
if !reflect.DeepEqual(b, out) {
|
||||
t.Fatalf("%+v != %+v", b, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicFloat32(t *testing.T) {
|
||||
var in float32 = 1.0
|
||||
out := ToFloat32(FromFloat32(in))
|
||||
if !reflect.DeepEqual(in, out) {
|
||||
t.Fatalf("%+v != %+v", in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexFloat32(t *testing.T) {
|
||||
var in float32 = 123456789123456789.123456789
|
||||
var want float32 = 123286039799267328.0
|
||||
out := ToFloat32(FromFloat32(in))
|
||||
if in == out {
|
||||
t.Fatalf("no loss of precision")
|
||||
}
|
||||
if out != want {
|
||||
t.Fatalf("%.16f != %.16f", want, out)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user