Compare commits

...

2 Commits

Author SHA1 Message Date
jmorganca
006c4bed4c update default repeat penalty 2026-03-01 18:17:33 -08:00
jmorganca
0d9e6bce00 runner: add repeat based sampling to ollama runner 2026-03-01 14:40:04 -08:00
8 changed files with 184 additions and 15 deletions

View File

@@ -1077,7 +1077,7 @@ func DefaultOptions() Options {
TopP: 0.9,
TypicalP: 1.0,
RepeatLastN: 64,
RepeatPenalty: 1.1,
RepeatPenalty: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
Seed: -1,

View File

@@ -152,7 +152,7 @@ PARAMETER <parameter> <parametervalue>
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |

View File

@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
seq.sampler.Reset()
// Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
// (unless we take down the whole runner).
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
for _, inp := range seq.pendingInputs {
if len(inp.Multimodal) != 0 {
continue
}
seq.sampler.Accept(inp.Token)
}
seq.pendingInputs = []*input.Input{}
}
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.RepeatPenalty,
req.Options.PresencePenalty,
req.Options.FrequencyPenalty,
req.Options.Seed,
grammar,
)
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
seq.sampler.Reset()
for _, inp := range seq.cache.Inputs {
if len(inp.Multimodal) != 0 {
continue
}
seq.sampler.Accept(inp.Token)
}
s.seqs[i] = seq
s.cond.Signal()
found = true

View File

@@ -16,24 +16,49 @@ type token struct {
value float32 // The raw logit or probability from the model
}
const DefaultPenaltyLookback = 64
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
repeat float32
presence float32
frequency float32
history []int32
grammar *GrammarSampler
}
func (s *Sampler) Reset() {
s.history = s.history[:0]
}
func (s *Sampler) Accept(token int32) {
s.history = append(s.history, token)
if len(s.history) > DefaultPenaltyLookback {
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
s.history = s.history[:DefaultPenaltyLookback]
}
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
counts := tokenCounts(s.history, len(logits))
tokens := make([]token, len(logits))
for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i)
tokens[i].value = logits[i]
tokens[i].value = value
}
t, err := s.sample(tokens)
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
// we need to reset them before applying the grammar and
// sampling again
for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i)
tokens[i].value = logits[i]
tokens[i].value = value
}
s.grammar.Apply(tokens)
t, err = s.sample(tokens)
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
minP = 1.0
}
if repeatPenalty <= 0 {
repeatPenalty = 1.0
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
repeat: repeatPenalty,
presence: presencePenalty,
frequency: frequencyPenalty,
grammar: grammar,
}
}

View File

@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
sampler.Sample(logits)
b.ResetTimer()
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
b.ResetTimer()
for b.Loop() {

View File

@@ -13,7 +13,7 @@ import (
func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
}
// Generate random logits for benchmarking

View File

@@ -25,6 +25,41 @@ func (h *tokenHeap) Pop() any {
return x
}
func tokenCounts(history []int32, vocabSize int) map[int32]int {
if len(history) == 0 {
return nil
}
start := 0
if len(history) > DefaultPenaltyLookback {
start = len(history) - DefaultPenaltyLookback
}
counts := make(map[int32]int, len(history)-start)
for _, token := range history[start:] {
if token < 0 || int(token) >= vocabSize {
continue
}
counts[token]++
}
return counts
}
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
if repeatPenalty != 1.0 {
// Preserve ordering for negative logits when applying repeat penalty.
if logit < 0 {
logit *= repeatPenalty
} else {
logit /= repeatPenalty
}
}
logit -= float32(count)*frequencyPenalty + presencePenalty
return logit
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability

View File

@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
}
}
func TestTokenCounts(t *testing.T) {
history := make([]int32, 70)
history[0] = 7
history[69] = 7
counts := tokenCounts(history, 8)
if got := counts[7]; got != 1 {
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
}
}
func TestApplyPenalty(t *testing.T) {
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
}
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
}
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
}
}
func TestSamplerPresencePenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 0.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
presence.Accept(1)
got, err = presence.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got == 1 {
t.Fatalf("presence penalty did not change repeated token selection")
}
}
func TestSamplerFrequencyPenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 4.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
baseline.Accept(1)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
frequency.Accept(1)
frequency.Accept(1)
frequency.Accept(1)
got, err = frequency.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 2 {
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
}
}
func BenchmarkTransforms(b *testing.B) {
// Generate random logits
tokens := make([]token, 1<<16)