Compare commits

..

1 Commit

Author SHA1 Message Date
jmorganca
bac80afe6a runner: discard compute results if sequence replaced mid-batch
If a sequence is replaced in s.seqs while a batch is computing, the old logits can be decoded into the new sequence. This change rechecks the sequence pointer after compute and skips decoding for replaced entries, preventing stale results from being applied.
2026-02-04 09:05:26 -08:00
2 changed files with 341 additions and 1 deletion

View File

@@ -740,7 +740,11 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
// If the sequence was replaced while this batch was computing, discard results.
if activeBatch.seqs[i] != seq {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)

View File

@@ -0,0 +1,336 @@
package ollamarunner
import (
"sync"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
// testContext is a minimal ml.Context stub for batch-processing tests.
// onCompute, when non-nil, runs inside Compute/ComputeWithNotify so a test
// can mutate server state while a batch is "computing".
type testContext struct {
	onCompute func()
}

func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { return &testTensor{dims: shape} }
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return &testTensor{dims: shape} }
func (c *testContext) FromBytes(dtype ml.DType, _ []byte, shape ...int) ml.Tensor {
	return &testTensor{dims: shape}
}
func (c *testContext) FromFloats(_ []float32, shape ...int) ml.Tensor {
	return &testTensor{dims: shape}
}
func (c *testContext) FromInts(_ []int32, shape ...int) ml.Tensor { return &testTensor{dims: shape} }
func (c *testContext) Arange(_, _, _ float32, _ ml.DType) ml.Tensor { return &testTensor{} }
func (c *testContext) Forward(_ ...ml.Tensor) ml.Context { return c }
func (c *testContext) SetBatchSize(_ int) {}

// Compute fires the test hook, standing in for the real forward pass.
func (c *testContext) Compute(_ ...ml.Tensor) {
	if c.onCompute != nil {
		c.onCompute()
	}
}

// ComputeWithNotify signals readiness via notify first, then fires the hook.
func (c *testContext) ComputeWithNotify(notify func(), _ ...ml.Tensor) {
	if notify != nil {
		notify()
	}
	if c.onCompute != nil {
		c.onCompute()
	}
}

func (c *testContext) Reserve()               {}
func (c *testContext) MaxGraphNodes() int     { return 0 }
func (c *testContext) Close()                 {}
func (c *testContext) Input() ml.Context      { return c }
func (c *testContext) Layer(_ int) ml.Context { return c }
// testTensor is a no-op ml.Tensor stub: nearly every operation returns the
// receiver unchanged. Only dims and floats hold state, which is enough for
// the batch tests to read shapes and logits.
type testTensor struct {
	dims   []int
	floats []float32
}

// Dim reports the size of dimension n, or 0 when n is out of range.
func (t *testTensor) Dim(n int) int {
	if n < 0 || n >= len(t.dims) {
		return 0
	}
	return t.dims[n]
}

func (t *testTensor) Stride(_ int) int { return 0 }
func (t *testTensor) Shape() []int     { return append([]int(nil), t.dims...) }
func (t *testTensor) DType() ml.DType  { return ml.DTypeF32 }

func (t *testTensor) Cast(_ ml.Context, _ ml.DType) ml.Tensor { return t }

func (t *testTensor) Bytes() []byte               { return nil }
func (t *testTensor) Floats() []float32           { return t.floats }
func (t *testTensor) FromBytes(_ []byte)          {}
func (t *testTensor) FromFloats(v []float32)      { t.floats = v }
func (t *testTensor) FromInts(_ []int32)          {}

// Arithmetic and neural-net ops are all identity stubs.
func (t *testTensor) Add(_ ml.Context, _ ml.Tensor) ml.Tensor            { return t }
func (t *testTensor) Sub(_ ml.Context, _ ml.Tensor) ml.Tensor            { return t }
func (t *testTensor) Mul(_ ml.Context, _ ml.Tensor) ml.Tensor            { return t }
func (t *testTensor) Div(_ ml.Context, _ ml.Tensor) ml.Tensor            { return t }
func (t *testTensor) Mulmat(_ ml.Context, _ ml.Tensor) ml.Tensor         { return t }
func (t *testTensor) MulmatFullPrec(_ ml.Context, _ ml.Tensor) ml.Tensor { return t }
func (t *testTensor) MulmatID(_ ml.Context, _ ml.Tensor, _ ml.Tensor) ml.Tensor { return t }
func (t *testTensor) AddID(_ ml.Context, _ ml.Tensor, _ ml.Tensor) ml.Tensor    { return t }
func (t *testTensor) Softmax(_ ml.Context) ml.Tensor                            { return t }
func (t *testTensor) L2Norm(_ ml.Context, _ float32) ml.Tensor                  { return t }
func (t *testTensor) LayerNorm(_ ml.Context, _ ml.Tensor, _ ml.Tensor, _ float32) ml.Tensor {
	return t
}
func (t *testTensor) RMSNorm(_ ml.Context, _ ml.Tensor, _ float32) ml.Tensor { return t }
func (t *testTensor) Scale(_ ml.Context, _ float64) ml.Tensor                { return t }
func (t *testTensor) SumRows(_ ml.Context) ml.Tensor                         { return t }
func (t *testTensor) AvgPool2D(_ ml.Context, _, _ int, _ float32) ml.Tensor  { return t }
func (t *testTensor) Conv2D(_ ml.Context, _ ml.Tensor, _, _, _, _, _, _ int) ml.Tensor { return t }
func (t *testTensor) Conv3D(_ ml.Context, _ ml.Tensor, _, _, _, _, _, _, _, _, _, _ int) ml.Tensor {
	return t
}
func (t *testTensor) SSMConv(_ ml.Context, _ ml.Tensor) ml.Tensor { return t }
func (t *testTensor) IM2Col(_ ml.Context, _ ml.Tensor, _, _, _, _, _, _ int) ml.Tensor { return t }
func (t *testTensor) Sin(_ ml.Context) ml.Tensor                         { return t }
func (t *testTensor) Cos(_ ml.Context) ml.Tensor                         { return t }
func (t *testTensor) Tanh(_ ml.Context) ml.Tensor                        { return t }
func (t *testTensor) GELU(_ ml.Context, _ ...ml.Tensor) ml.Tensor        { return t }
func (t *testTensor) GELU_ERF(_ ml.Context) ml.Tensor                    { return t }
func (t *testTensor) QuickGELU(_ ml.Context, _ ...ml.Tensor) ml.Tensor   { return t }
func (t *testTensor) SILU(_ ml.Context, _ ...ml.Tensor) ml.Tensor        { return t }
func (t *testTensor) RELU(_ ml.Context, _ ...ml.Tensor) ml.Tensor        { return t }
func (t *testTensor) Sigmoid(_ ml.Context) ml.Tensor                     { return t }
func (t *testTensor) SILUAlphaLimit(_ ml.Context, _ ml.Tensor, _, _ float32) ml.Tensor { return t }

// Reshape records the requested shape so later Dim/Shape calls observe it.
func (t *testTensor) Reshape(_ ml.Context, shape ...int) ml.Tensor {
	t.dims = append([]int(nil), shape...)
	return t
}

// View also records the requested shape, mirroring Reshape.
func (t *testTensor) View(_ ml.Context, _ int, shape ...int) ml.Tensor {
	t.dims = append([]int(nil), shape...)
	return t
}

func (t *testTensor) Permute(_ ml.Context, _ ...int) ml.Tensor                { return t }
func (t *testTensor) Contiguous(_ ml.Context, _ ...int) ml.Tensor             { return t }
func (t *testTensor) Pad(_ ml.Context, _ ...int) ml.Tensor                    { return t }
func (t *testTensor) Stack(_ ml.Context, _ int, _ ...ml.Tensor) ml.Tensor     { return t }
func (t *testTensor) Repeat(_ ml.Context, _, _ int) ml.Tensor                 { return t }
func (t *testTensor) Concat(_ ml.Context, _ ml.Tensor, _ int) ml.Tensor       { return t }
func (t *testTensor) Rows(_ ml.Context, _ ml.Tensor) ml.Tensor                { return t }
func (t *testTensor) SetRows(_ ml.Context, _ ml.Tensor, _ ml.Tensor) ml.Tensor { return t }
func (t *testTensor) Copy(_ ml.Context, _ ml.Tensor) ml.Tensor                { return t }
func (t *testTensor) Duplicate(_ ml.Context) ml.Tensor                        { return t }
func (t *testTensor) Slice(_ ml.Context, _, _, _, _ int) ml.Tensor            { return t }
func (t *testTensor) Chunk(_ ml.Context, _ int, _ int) []ml.Tensor            { return nil }
func (t *testTensor) ChunkSections(_ ml.Context, _ int, _ ...int) []ml.Tensor { return nil }
func (t *testTensor) TopK(_ ml.Context, _ int) ml.Tensor                      { return t }
func (t *testTensor) Argsort(_ ml.Context) ml.Tensor                          { return t }
func (t *testTensor) Mean(_ ml.Context) ml.Tensor                             { return t }
func (t *testTensor) Variance(_ ml.Context) ml.Tensor                         { return t }
func (t *testTensor) Stddev(_ ml.Context) ml.Tensor                           { return t }
func (t *testTensor) Sqr(_ ml.Context) ml.Tensor                              { return t }
func (t *testTensor) Sqrt(_ ml.Context) ml.Tensor                             { return t }
func (t *testTensor) Exp(_ ml.Context) ml.Tensor                              { return t }
func (t *testTensor) Neg(_ ml.Context) ml.Tensor                              { return t }
func (t *testTensor) Clamp(_ ml.Context, _, _ float32) ml.Tensor              { return t }
func (t *testTensor) Softplus(_ ml.Context) ml.Tensor                         { return t }
func (t *testTensor) CumSum(_ ml.Context) ml.Tensor                           { return t }
func (t *testTensor) Diag(_ ml.Context) ml.Tensor                             { return t }
func (t *testTensor) Tri(_ ml.Context, _ int) ml.Tensor                       { return t }
func (t *testTensor) Fill(_ ml.Context, _ float32) ml.Tensor                  { return t }
func (t *testTensor) Repeat4D(_ ml.Context, _, _, _, _ int) ml.Tensor         { return t }
func (t *testTensor) SolveTri(_ ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return t }
func (t *testTensor) Interpolate(_ ml.Context, _ [4]int, _ ml.SamplingMode) ml.Tensor { return t }
// TestComputeBatchDiscardsReplacedSequence verifies that computeBatch drops
// the results for a slot whose sequence was swapped out while the batch was
// computing: the replacement sequence must receive no predictions and no
// responses derived from the stale batch.
func TestComputeBatchDiscardsReplacedSequence(t *testing.T) {
	server := &Server{
		seqs: make([]*Sequence, 1),
	}
	server.cond = sync.NewCond(&server.mu)
	server.model = &testTextModel{}
	// oldSeq occupies slot 0 when the batch is built.
	oldSeq := &Sequence{
		cache: &InputCacheSlot{Id: 0, Inputs: []*input.Input{}},
	}
	// newSeq replaces it mid-compute. Channels are buffered so an erroneous
	// send would be recorded (and detected below) instead of blocking.
	newSeq := &Sequence{
		cache:     &InputCacheSlot{Id: 0, Inputs: []*input.Input{}},
		responses: make(chan response, 1),
		quit:      make(chan bool, 1),
	}
	server.seqs[0] = oldSeq
	// The compute hook swaps slot 0 to newSeq under the server lock,
	// simulating a sequence being replaced while the forward pass runs.
	ctx := &testContext{
		onCompute: func() {
			server.mu.Lock()
			server.seqs[0] = newSeq
			server.mu.Unlock()
		},
	}
	// A batch snapshot that still references oldSeq in its seqs slice.
	active := batchState{
		id:               1,
		ctx:              ctx,
		seqs:             []*Sequence{oldSeq},
		batchInputs:      []*input.Input{},
		inputsReadyCh:    make(chan struct{}, 1),
		computeStartedCh: make(chan struct{}, 1),
		outputsReadyCh:   make(chan struct{}, 1),
		batch: input.Batch{
			Inputs:  &testTensor{},
			Outputs: &testTensor{dims: []int{1}},
		},
		// Pretend the model produced logits for two vocabulary entries.
		modelOutput: &testTensor{floats: []float32{0.1, 0.9}},
	}
	// Pre-signal input readiness so computeBatch does not block waiting.
	active.inputsReadyCh <- struct{}{}
	server.computeBatch(active)
	// The stale batch must not have been decoded into the replacement.
	if newSeq.numPredicted != 0 {
		t.Fatalf("replaced sequence was sampled: numPredicted=%d", newSeq.numPredicted)
	}
	if got := len(newSeq.responses); got != 0 {
		t.Fatalf("unexpected response emitted for replaced sequence: %d", got)
	}
}
// Compile-time assertions that the test stubs satisfy the ml interfaces.
var _ ml.Context = (*testContext)(nil)
var _ ml.Tensor = (*testTensor)(nil)
// testTextModel is a stub text model: Forward yields a single fixed logit
// and the tokenizer methods return constant values.
type testTextModel struct {
	model.Base
}

// Forward returns a one-element logits tensor regardless of the batch.
func (m *testTextModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
	return &testTensor{dims: []int{1}, floats: []float32{0}}, nil
}

// Encode reports no tokens for any prompt.
func (m *testTextModel) Encode(string, bool) ([]int32, error) {
	return nil, nil
}

// Decode renders every token sequence as the fixed string "X".
func (m *testTextModel) Decode([]int32) (string, error) {
	return "X", nil
}

// Is reports that no token is special.
func (m *testTextModel) Is(int32, model.Special) bool {
	return false
}

// Vocabulary returns an empty vocabulary.
func (m *testTextModel) Vocabulary() *model.Vocabulary {
	return &model.Vocabulary{}
}