Compare commits

...

2 Commits

Author SHA1 Message Date
jmorganca
c330ea33ed qwen3next: handle mixed recurrent batches
Allow mixed token-count batches by tracking per-seq indices
and falling back to per-seq recurrent processing when layouts
differ.

Add per-slot conv/delta state access with checkpoint capture,
relax attention layout handling, and reuse projections in mixed
batches to reduce overhead.
2026-02-05 11:50:00 -08:00
Jesse Gross
c61023f554 ollamarunner: Fix off by one error with numPredict
When numPredict is set, the user will receive one less token
than the requested limit. In addition, the stats will incorrectly
show the number of tokens returned as the limit. In cases where
numPredict is not set, the number of tokens is reported correctly.

This occurs because numPredict is checked when setting up the next
batch, but hitting the limit also terminates the current batch.
Instead, it is better to check the limit as tokens are actually
predicted.
2026-02-04 17:14:24 -08:00
5 changed files with 326 additions and 48 deletions

View File

@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
// TestNumPredict verifies that when num_predict is set, the model generates
// exactly that many tokens. It uses logprobs to count the actual tokens output.
func TestNumPredict(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
t.Fatalf("failed to pull model: %v", err)
}
req := api.GenerateRequest{
Model: "qwen3:0.6b",
Prompt: "Write a long story.",
Stream: &stream,
Logprobs: true,
Options: map[string]any{
"num_predict": 10,
"temperature": 0,
"seed": 123,
},
}
logprobCount := 0
var finalResponse api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
logprobCount += len(resp.Logprobs)
if resp.Done {
finalResponse = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %v", err)
}
if logprobCount != 10 {
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
}
}

View File

@@ -39,12 +39,15 @@ func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
if nSeqs > 0 {
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
if batchSize != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
// Fallback: treat as flat batch if layout doesn't match.
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize*nSeqs)
batchSize = batchSize * nSeqs
} else {
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
} else if batchSize != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
// Layout mismatch; proceed with flat batch.
}
}
}
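
The change above replaces a hard ErrUnsupportedBatchLayout with a reshape fallback when the 3D hidden-state layout does not match the expected [seqTokens x nSeqs] grid. As a rough illustration only, here is a standalone sketch of that decision on plain dimensions; the shape type and the flattenIfMixed helper are hypothetical stand-ins for the real ml.Tensor handling, not code from this repository.

package main

import "fmt"

// shape is a hypothetical stand-in for tensor dimensions:
// [hiddenDim, tokens] or [hiddenDim, seqTokens, nSeqs].
type shape []int

// flattenIfMixed mirrors the fallback above: when a 3D batch does not form
// the expected [seqTokens x nSeqs] grid, collapse it to a flat
// [hiddenDim, totalTokens] layout instead of returning an error.
func flattenIfMixed(s shape, seqTokens, nSeqs int) shape {
	if len(s) != 3 {
		return s
	}
	hiddenDim, batchSize, gridSeqs := s[0], s[1], s[2]
	if batchSize != seqTokens || gridSeqs != nSeqs {
		// Mixed layout: treat as a flat batch.
		return shape{hiddenDim, batchSize * gridSeqs}
	}
	// Uniform grid: flatten the known grid dimensions.
	return shape{hiddenDim, seqTokens * nSeqs}
}

func main() {
	fmt.Println(flattenIfMixed(shape{1024, 7, 2}, 4, 2)) // mixed: [1024 14]
	fmt.Println(flattenIfMixed(shape{1024, 4, 2}, 4, 2)) // uniform: [1024 8]
}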

View File

@@ -64,6 +64,8 @@ type HybridCache struct {
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// token indices per sequence in batch order
curSeqTokenIdxs [][]int32
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
@@ -168,19 +170,44 @@ func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bo
}
if len(c.curSeqs) == 0 {
c.curSeqTokenIdxs = c.curSeqTokenIdxs[:0]
return nil
}
if cap(c.curSeqTokenIdxs) < len(c.curSeqs) {
c.curSeqTokenIdxs = make([][]int32, len(c.curSeqs))
} else {
c.curSeqTokenIdxs = c.curSeqTokenIdxs[:len(c.curSeqs)]
}
for i := range c.curSeqTokenIdxs {
c.curSeqTokenIdxs[i] = c.curSeqTokenIdxs[i][:0]
}
seqIndex := make(map[int]int, len(c.curSeqs))
for i, s := range c.curSeqs {
seqIndex[s] = i
}
for i, s := range batch.Sequences {
c.curSeqTokenIdxs[seqIndex[s]] = append(c.curSeqTokenIdxs[seqIndex[s]], int32(i))
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
uniform := true
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
uniform = false
break
}
}
c.curSeqTokens = want
if uniform {
c.curSeqTokens = want
} else {
// Mixed batch: recurrent layers will process sequences independently.
c.curSeqTokens = 0
}
// When reserving memory for estimation, use fake slot assignments
if reserve {
@@ -585,7 +612,101 @@ func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Te
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
// convStateForSlot returns the conv state for a single slot as [convDim, convChannels, 1].
func (c *HybridCache) convStateForSlot(ctx ml.Context, layer int, slot int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
cur := buf.Rows(ctx, slotIdx)
return cur.Reshape(ctx, c.convDim, c.convChannels, 1), nil
}
// updateConvStateForSlot writes a new conv state for a single slot.
func (c *HybridCache) updateConvStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, 1)
srcF32 := src.Cast(ctx, ml.DTypeF32)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
c.captureConvCheckpointForSeq(ctx, layer, seqIndex, srcF32)
}
// deltaStateForSlot returns the delta state for a single slot as [headVDim, headVDim*numVHeads, 1].
func (c *HybridCache) deltaStateForSlot(ctx ml.Context, layer int, slot int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
cur := buf.Rows(ctx, slotIdx)
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, 1), nil
}
// updateDeltaStateForSlot writes a new delta state for a single slot.
func (c *HybridCache) updateDeltaStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, 1)
srcF32 := src.Cast(ctx, ml.DTypeF32)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
c.captureDeltaCheckpointForSeq(ctx, layer, seqIndex, srcF32)
}
func (c *HybridCache) captureConvCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
return
}
pos := c.curCheckpointPos[seqIndex]
if pos < 0 {
return
}
slot := c.curSlots[seqIndex]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
return
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
ctx.Forward(src.Copy(ctx, dst))
}
func (c *HybridCache) captureDeltaCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
return
}
pos := c.curCheckpointPos[seqIndex]
if pos < 0 {
return
}
slot := c.curSlots[seqIndex]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
return
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
ctx.Forward(src.Copy(ctx, dst))
}
// IsSupportedForBatch returns true if the current batch layout supports grid-style recurrent processing.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
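
The StartForward bookkeeping above records, for each active sequence, which batch positions belong to it (curSeqTokenIdxs) and falls back to curSeqTokens = 0 when sequences contribute unequal token counts. The self-contained sketch below shows the same grouping on plain slices; the groupTokensBySeq helper is hypothetical and ignores the cache's buffer reuse and reserve path.

package main

import "fmt"

// groupTokensBySeq groups token indices by sequence ID in first-appearance
// order and reports the per-sequence token count, or 0 for a mixed batch
// whose sequences have unequal token counts.
func groupTokensBySeq(batchSeqs []int) (idxs [][]int32, seqTokens int) {
	if len(batchSeqs) == 0 {
		return nil, 0
	}
	var order []int
	byID := map[int][]int32{}
	for i, s := range batchSeqs {
		if _, ok := byID[s]; !ok {
			order = append(order, s)
		}
		byID[s] = append(byID[s], int32(i))
	}
	want := len(batchSeqs) / len(order)
	uniform := true
	for _, s := range order {
		idxs = append(idxs, byID[s])
		if len(byID[s]) != want {
			uniform = false
		}
	}
	if !uniform {
		// Mixed batch: recurrent layers process each sequence independently.
		return idxs, 0
	}
	return idxs, want
}

func main() {
	idxs, n := groupTokensBySeq([]int{0, 0, 0, 1, 1})
	fmt.Println(idxs, n) // [[0 1 2] [3 4]] 0 -> mixed batch
	idxs, n = groupTokensBySeq([]int{0, 0, 1, 1})
	fmt.Println(idxs, n) // [[0 1] [2 3]] 2 -> uniform grid
}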

View File

@@ -48,6 +48,13 @@ type GatedDeltaNet struct {
Layer int
}
type stateAccessors struct {
convState func() (ml.Tensor, error)
updateConv func(ml.Tensor)
deltaState func() (ml.Tensor, error)
updateDelta func(ml.Tensor)
}
// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
@@ -68,7 +75,6 @@ func createMasks(ctx ml.Context) *Masks {
}
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := gdn.Layer
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
if cache != nil && cache.IsSupportedForBatch() {
@@ -77,34 +83,140 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
if seqTokens > 0 && seqs > 0 {
if nSeqs > 1 {
if nSeqTokens != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
} else {
if nSeqTokens != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
nSeqTokens = seqTokens
nSeqs = seqs
}
}
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
layer := gdn.Layer
access := stateAccessors{
convState: func() (ml.Tensor, error) {
return cache.ConvState(ctx, layer)
},
updateConv: func(newState ml.Tensor) {
cache.UpdateConvState(ctx, layer, newState)
},
deltaState: func() (ml.Tensor, error) {
return cache.DeltaState(ctx, layer, headVDim, numVHeads)
},
updateDelta: func(newState ml.Tensor) {
cache.UpdateDeltaState(ctx, layer, newState)
},
}
return gdn.forwardWithAccessors(ctx, hiddenStates, opts, access)
}
if cache == nil {
return nil, ErrUnsupportedBatchLayout
}
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
func (gdn *GatedDeltaNet) forwardMixed(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
if hiddenStates.Dim(2) > 0 {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2))
}
if len(cache.curSeqs) == 0 {
return hiddenStates, nil
}
// Ensure any shared slots are detached once for this forward pass.
cache.ensureWritableOnce(ctx)
layer := gdn.Layer
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Precompute projections for the full batch and slice per sequence.
mixedBAFull := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvMixedFull := gdn.SSMQKV.Forward(ctx, hiddenStates)
zFull := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
out := hiddenStates
for seqIndex := range cache.curSeqs {
idxs := cache.curSeqTokenIdxs[seqIndex]
if len(idxs) == 0 {
continue
}
idxTensor := ctx.Input().FromInts(idxs, len(idxs))
mixedBA := mixedBAFull.Rows(ctx, idxTensor)
qkvMixed := qkvMixedFull.Rows(ctx, idxTensor)
z := zFull.Rows(ctx, idxTensor)
slot := cache.curSlots[seqIndex]
access := stateAccessors{
convState: func() (ml.Tensor, error) {
return cache.convStateForSlot(ctx, layer, slot)
},
updateConv: func(newState ml.Tensor) {
cache.updateConvStateForSlot(ctx, layer, slot, seqIndex, newState)
},
deltaState: func() (ml.Tensor, error) {
return cache.deltaStateForSlot(ctx, layer, slot, headVDim, numVHeads)
},
updateDelta: func(newState ml.Tensor) {
cache.updateDeltaStateForSlot(ctx, layer, slot, seqIndex, newState)
},
}
seqOut, err := gdn.forwardProjected(ctx, len(idxs), 1, mixedBA, qkvMixed, z, opts, access)
if err != nil {
return nil, err
}
out = out.SetRows(ctx, seqOut, idxTensor)
}
return out, nil
}
func (gdn *GatedDeltaNet) forwardWithAccessors(ctx ml.Context, hiddenStates ml.Tensor, opts *Options, access stateAccessors) (ml.Tensor, error) {
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
return gdn.forwardProjected(ctx, nSeqTokens, nSeqs, mixedBA, qkvMixed, z, opts, access)
}
func (gdn *GatedDeltaNet) forwardProjected(
ctx ml.Context,
nSeqTokens, nSeqs int,
mixedBA, qkvMixed, z ml.Tensor,
opts *Options,
access stateAccessors,
) (ml.Tensor, error) {
layer := gdn.Layer
headKDim := opts.ssmDState
numKHeads := opts.ssmNGroup
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
qkvMixed = qkvMixed.Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
@@ -127,7 +239,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
convStates, err := access.convState()
if err != nil {
// Log this - if it happens, short-term context will be lost
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
@@ -142,7 +254,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
// Save new conv state (last convKernelSize-1 tokens)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
access.updateConv(lastConvStates)
// Apply SSM convolution (kernel must be F32 for Metal)
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
@@ -162,7 +274,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
// Get delta state from cache
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
state, err := access.deltaState()
if err != nil {
// Log this - if it happens frequently, context will degrade
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
@@ -185,14 +297,19 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
}
// Choose computation mode based on sequence length
var attnOut ml.Tensor
var (
attnOut ml.Tensor
newState ml.Tensor
)
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
attnOut, newState = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
attnOut, newState = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts)
}
access.updateDelta(newState)
// Apply gated normalization
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
@@ -215,9 +332,7 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
) (ml.Tensor, ml.Tensor) {
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nSeqs := q.Dim(3)
@@ -273,10 +388,8 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
coreAttnOut := stateQ.SumRows(ctx)
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
newState := state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs), newState
}
// deltaNetChunked implements chunked computation for prefill.
@@ -286,9 +399,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q, k, v, gate, beta, state ml.Tensor,
masks *Masks,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
) (ml.Tensor, ml.Tensor) {
headKDim := q.Dim(0)
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
@@ -465,8 +576,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
newStateFlat := newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs), newStateFlat
}
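
The stateAccessors indirection above lets the same projected forward pass run whether its recurrent state lives in the whole-batch cache entry (grid path) or in a single slot of a mixed batch (forwardMixed binds closures per slot). A minimal sketch of the pattern, with []float32 standing in for ml.Tensor and a hypothetical step function in place of the real delta-net computation:

package main

import "fmt"

// accessors hides where the recurrent state lives; only the closures know
// about slots.
type accessors struct {
	state  func() []float32
	update func([]float32)
}

// step is a stand-in for forwardProjected: read the state, derive a new one,
// and hand it back through the accessor.
func step(a accessors, x float32) {
	s := a.state()
	for i := range s {
		s[i] += x
	}
	a.update(s)
}

func main() {
	// Per-slot binding, as in forwardMixed: each sequence gets closures bound
	// to its own slot.
	cache := [][]float32{{1, 2}, {10, 20}}
	for slot := range cache {
		slot := slot
		step(accessors{
			state:  func() []float32 { return append([]float32(nil), cache[slot]...) },
			update: func(s []float32) { cache[slot] = s },
		}, 0.5)
	}
	fmt.Println(cache) // [[1.5 2.5] [10.5 20.5]]
}

In the uniform grid path, the same closures simply wrap the batched ConvState/DeltaState and UpdateConvState/UpdateDeltaState calls.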

View File

@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
seq.pendingResponses = append(seq.pendingResponses, piece)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, llm.DoneReasonLength)
continue
}
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := common.FindStop(sequence, seq.stop); ok {
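
Note on the runner change above: the numPredict check moves from batch setup into computeBatch, immediately after numPredicted is incremented, so a request with num_predict = n returns exactly n tokens. The sketch below shows the corrected ordering in isolation; generate and the nextToken sampler are hypothetical stand-ins for the real batch pipeline.

package main

import "fmt"

// generate counts each token as soon as it is produced and checks the limit
// right afterwards. Checking while preparing the next batch (the old
// placement) also cancelled the batch that would have delivered the nth
// token, which is why users saw one token fewer than the limit.
func generate(numPredict int, nextToken func() string) []string {
	var out []string
	numPredicted := 0
	for {
		tok := nextToken()
		numPredicted++
		out = append(out, tok)
		// if past the num predict limit
		if numPredict > 0 && numPredicted >= numPredict {
			return out // DoneReasonLength
		}
	}
}

func main() {
	i := 0
	out := generate(10, func() string { i++; return fmt.Sprintf("t%d", i) })
	fmt.Println(len(out)) // 10, matching num_predict
}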