Compare commits

...

1 Commit

Author SHA1 Message Date
jmorganca
c330ea33ed qwen3next: handle mixed recurrent batches
Allow mixed token-count batches by tracking per-seq indices and falling back to per-seq recurrent processing when layouts differ.

Add per-slot conv/delta state access with checkpoint capture, relax attention layout handling, and reuse projections in mixed batches to reduce overhead.
2026-02-05 11:50:00 -08:00
3 changed files with 273 additions and 40 deletions
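Note: the sketch below illustrates the per-sequence index bookkeeping this commit adds to StartForward (the curSeqTokenIdxs field). It is a minimal, self-contained version of that loop; groupTokensBySeq and its arguments are illustrative stand-ins, not the cache's actual API.

package main

import "fmt"

// groupTokensBySeq mirrors the bookkeeping added to StartForward: for each
// active sequence, collect the batch positions of its tokens so a mixed
// batch can later be processed one sequence at a time.
func groupTokensBySeq(batchSequences, curSeqs []int) [][]int32 {
	seqIndex := make(map[int]int, len(curSeqs))
	for i, s := range curSeqs {
		seqIndex[s] = i
	}
	idxs := make([][]int32, len(curSeqs))
	for i, s := range batchSequences {
		idxs[seqIndex[s]] = append(idxs[seqIndex[s]], int32(i))
	}
	return idxs
}

func main() {
	// Mixed batch: sequence 7 contributes three tokens, sequence 9 one.
	fmt.Println(groupTokensBySeq([]int{7, 7, 9, 7}, []int{7, 9})) // [[0 1 3] [2]]
}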

View File

@@ -39,12 +39,15 @@ func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
if nSeqs > 0 {
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
if batchSize != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
// Fallback: treat as flat batch if layout doesn't match.
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize*nSeqs)
batchSize = batchSize * nSeqs
} else {
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
} else if batchSize != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
// Layout mismatch; proceed with flat batch.
}
}
}

View File

@@ -64,6 +64,8 @@ type HybridCache struct {
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// token indices per sequence in batch order
curSeqTokenIdxs [][]int32
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
@@ -168,19 +170,44 @@ func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bo
}
if len(c.curSeqs) == 0 {
c.curSeqTokenIdxs = c.curSeqTokenIdxs[:0]
return nil
}
if cap(c.curSeqTokenIdxs) < len(c.curSeqs) {
c.curSeqTokenIdxs = make([][]int32, len(c.curSeqs))
} else {
c.curSeqTokenIdxs = c.curSeqTokenIdxs[:len(c.curSeqs)]
}
for i := range c.curSeqTokenIdxs {
c.curSeqTokenIdxs[i] = c.curSeqTokenIdxs[i][:0]
}
seqIndex := make(map[int]int, len(c.curSeqs))
for i, s := range c.curSeqs {
seqIndex[s] = i
}
for i, s := range batch.Sequences {
c.curSeqTokenIdxs[seqIndex[s]] = append(c.curSeqTokenIdxs[seqIndex[s]], int32(i))
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
uniform := true
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
uniform = false
break
}
}
c.curSeqTokens = want
if uniform {
c.curSeqTokens = want
} else {
// Mixed batch: recurrent layers will process sequences independently.
c.curSeqTokens = 0
}
// When reserving memory for estimation, use fake slot assignments
if reserve {
@@ -585,7 +612,101 @@ func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Te
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
// convStateForSlot returns the conv state for a single slot as [convDim, convChannels, 1].
func (c *HybridCache) convStateForSlot(ctx ml.Context, layer int, slot int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
cur := buf.Rows(ctx, slotIdx)
return cur.Reshape(ctx, c.convDim, c.convChannels, 1), nil
}
// updateConvStateForSlot writes a new conv state for a single slot.
func (c *HybridCache) updateConvStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, 1)
srcF32 := src.Cast(ctx, ml.DTypeF32)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
c.captureConvCheckpointForSeq(ctx, layer, seqIndex, srcF32)
}
// deltaStateForSlot returns the delta state for a single slot as [headVDim, headVDim*numVHeads, 1].
func (c *HybridCache) deltaStateForSlot(ctx ml.Context, layer int, slot int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
cur := buf.Rows(ctx, slotIdx)
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, 1), nil
}
// updateDeltaStateForSlot writes a new delta state for a single slot.
func (c *HybridCache) updateDeltaStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, 1)
srcF32 := src.Cast(ctx, ml.DTypeF32)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
c.captureDeltaCheckpointForSeq(ctx, layer, seqIndex, srcF32)
}
func (c *HybridCache) captureConvCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
return
}
pos := c.curCheckpointPos[seqIndex]
if pos < 0 {
return
}
slot := c.curSlots[seqIndex]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
return
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
ctx.Forward(src.Copy(ctx, dst))
}
func (c *HybridCache) captureDeltaCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
return
}
pos := c.curCheckpointPos[seqIndex]
if pos < 0 {
return
}
slot := c.curSlots[seqIndex]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
return
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
ctx.Forward(src.Copy(ctx, dst))
}
// IsSupportedForBatch returns true if the current batch layout supports grid-style recurrent processing.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
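The uniform-vs-mixed decision above boils down to a single per-sequence token count check. A minimal sketch of that check, assuming plain ints in place of the cache's fields (layoutSeqTokens is an illustrative name, not a real method):

package main

import "fmt"

// layoutSeqTokens mirrors the check in StartForward: if every active
// sequence contributes the same number of tokens, the grid layout is used
// (curSeqTokens > 0); otherwise 0 signals that recurrent layers must fall
// back to per-sequence processing, and IsSupportedForBatch returns false.
func layoutSeqTokens(seqCounts map[int]int, curSeqs []int, nTokens int) int {
	want := nTokens / len(curSeqs)
	for _, s := range curSeqs {
		if seqCounts[s] != want {
			return 0 // mixed batch
		}
	}
	return want
}

func main() {
	fmt.Println(layoutSeqTokens(map[int]int{1: 4, 2: 4}, []int{1, 2}, 8)) // 4 (uniform)
	fmt.Println(layoutSeqTokens(map[int]int{1: 3, 2: 1}, []int{1, 2}, 4)) // 0 (mixed)
}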

View File

@@ -48,6 +48,13 @@ type GatedDeltaNet struct {
Layer int
}
type stateAccessors struct {
convState func() (ml.Tensor, error)
updateConv func(ml.Tensor)
deltaState func() (ml.Tensor, error)
updateDelta func(ml.Tensor)
}
// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
@@ -68,7 +75,6 @@ func createMasks(ctx ml.Context) *Masks {
}
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := gdn.Layer
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
if cache != nil && cache.IsSupportedForBatch() {
@@ -77,34 +83,140 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
if seqTokens > 0 && seqs > 0 {
if nSeqs > 1 {
if nSeqTokens != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
} else {
if nSeqTokens != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
nSeqTokens = seqTokens
nSeqs = seqs
}
}
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
layer := gdn.Layer
access := stateAccessors{
convState: func() (ml.Tensor, error) {
return cache.ConvState(ctx, layer)
},
updateConv: func(newState ml.Tensor) {
cache.UpdateConvState(ctx, layer, newState)
},
deltaState: func() (ml.Tensor, error) {
return cache.DeltaState(ctx, layer, headVDim, numVHeads)
},
updateDelta: func(newState ml.Tensor) {
cache.UpdateDeltaState(ctx, layer, newState)
},
}
return gdn.forwardWithAccessors(ctx, hiddenStates, opts, access)
}
if cache == nil {
return nil, ErrUnsupportedBatchLayout
}
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
func (gdn *GatedDeltaNet) forwardMixed(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
if hiddenStates.Dim(2) > 0 {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2))
}
if len(cache.curSeqs) == 0 {
return hiddenStates, nil
}
// Ensure any shared slots are detached once for this forward pass.
cache.ensureWritableOnce(ctx)
layer := gdn.Layer
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Precompute projections for the full batch and slice per sequence.
mixedBAFull := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvMixedFull := gdn.SSMQKV.Forward(ctx, hiddenStates)
zFull := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
out := hiddenStates
for seqIndex := range cache.curSeqs {
idxs := cache.curSeqTokenIdxs[seqIndex]
if len(idxs) == 0 {
continue
}
idxTensor := ctx.Input().FromInts(idxs, len(idxs))
mixedBA := mixedBAFull.Rows(ctx, idxTensor)
qkvMixed := qkvMixedFull.Rows(ctx, idxTensor)
z := zFull.Rows(ctx, idxTensor)
slot := cache.curSlots[seqIndex]
access := stateAccessors{
convState: func() (ml.Tensor, error) {
return cache.convStateForSlot(ctx, layer, slot)
},
updateConv: func(newState ml.Tensor) {
cache.updateConvStateForSlot(ctx, layer, slot, seqIndex, newState)
},
deltaState: func() (ml.Tensor, error) {
return cache.deltaStateForSlot(ctx, layer, slot, headVDim, numVHeads)
},
updateDelta: func(newState ml.Tensor) {
cache.updateDeltaStateForSlot(ctx, layer, slot, seqIndex, newState)
},
}
seqOut, err := gdn.forwardProjected(ctx, len(idxs), 1, mixedBA, qkvMixed, z, opts, access)
if err != nil {
return nil, err
}
out = out.SetRows(ctx, seqOut, idxTensor)
}
return out, nil
}
func (gdn *GatedDeltaNet) forwardWithAccessors(ctx ml.Context, hiddenStates ml.Tensor, opts *Options, access stateAccessors) (ml.Tensor, error) {
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
return gdn.forwardProjected(ctx, nSeqTokens, nSeqs, mixedBA, qkvMixed, z, opts, access)
}
func (gdn *GatedDeltaNet) forwardProjected(
ctx ml.Context,
nSeqTokens, nSeqs int,
mixedBA, qkvMixed, z ml.Tensor,
opts *Options,
access stateAccessors,
) (ml.Tensor, error) {
layer := gdn.Layer
headKDim := opts.ssmDState
numKHeads := opts.ssmNGroup
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
qkvMixed = qkvMixed.Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
@@ -127,7 +239,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
convStates, err := access.convState()
if err != nil {
// Log this - if it happens, short-term context will be lost
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
@@ -142,7 +254,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
// Save new conv state (last convKernelSize-1 tokens)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
access.updateConv(lastConvStates)
// Apply SSM convolution (kernel must be F32 for Metal)
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
@@ -162,7 +274,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
// Get delta state from cache
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
state, err := access.deltaState()
if err != nil {
// Log this - if it happens frequently, context will degrade
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
@@ -185,14 +297,19 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
}
// Choose computation mode based on sequence length
var attnOut ml.Tensor
var (
attnOut ml.Tensor
newState ml.Tensor
)
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
attnOut, newState = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
attnOut, newState = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts)
}
access.updateDelta(newState)
// Apply gated normalization
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
@@ -215,9 +332,7 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
) (ml.Tensor, ml.Tensor) {
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nSeqs := q.Dim(3)
@@ -273,10 +388,8 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
coreAttnOut := stateQ.SumRows(ctx)
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
newState := state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs), newState
}
// deltaNetChunked implements chunked computation for prefill.
@@ -286,9 +399,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q, k, v, gate, beta, state ml.Tensor,
masks *Masks,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
) (ml.Tensor, ml.Tensor) {
headKDim := q.Dim(0)
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
@@ -465,8 +576,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
newStateFlat := newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs), newStateFlat
}
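For reference, the mixed path's "project once, slice per sequence, write back" pattern is the same gather/scatter idea the diff expresses with ml.Tensor Rows/SetRows. The sketch below shows that pattern with plain float32 slices standing in for tensors; gatherRows and scatterRows are illustrative helpers, not part of the ml package.

package main

import "fmt"

// gatherRows and scatterRows illustrate the Rows/SetRows pattern forwardMixed
// relies on: project once for the whole batch, gather one sequence's rows,
// run the per-sequence recurrent step, then scatter the output back into the
// batch-sized result.
func gatherRows(full [][]float32, idxs []int32) [][]float32 {
	out := make([][]float32, len(idxs))
	for i, idx := range idxs {
		out[i] = full[idx]
	}
	return out
}

func scatterRows(dst, rows [][]float32, idxs []int32) {
	for i, idx := range idxs {
		dst[idx] = rows[i]
	}
}

func main() {
	full := [][]float32{{0}, {1}, {2}, {3}} // batch of 4 token rows
	seqIdxs := []int32{0, 1, 3}             // tokens belonging to one sequence
	seqRows := gatherRows(full, seqIdxs)
	for i := range seqRows { // stand-in for the per-sequence recurrent step
		seqRows[i] = []float32{seqRows[i][0] * 10}
	}
	scatterRows(full, seqRows, seqIdxs)
	fmt.Println(full) // [[0] [10] [2] [30]]
}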