Compare commits

...

1 Commits

Author SHA1 Message Date
Jeffrey Morgan
3490e9590b model/qwen3next: avoid crash in in DeltaNet when offloading (#14541)
Co-authored-by: Yossi Ovadia <jabadia@gmail.com>
2026-03-01 18:44:04 -08:00

View File

@@ -454,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// Collect chunk outputs and concatenate at the end.
// Avoids SET on buffer-less intermediates under partial offload.
chunks := make([]ml.Tensor, nChunks)
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
@@ -475,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
v = v.SetInplace(
ctx,
coreAttnOutChunk,
v.Stride(1),
v.Stride(2),
v.Stride(3),
chunk*v.Stride(2),
)
chunks[chunk] = coreAttnOutChunk
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
@@ -495,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
stateT = stateT.Add(ctx, kgdMulVNew)
}
// Use a balanced concat tree so concat work does not balloon on long prompts.
for len(chunks) > 1 {
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
for i := 0; i < len(chunks); i += 2 {
if i+1 < len(chunks) {
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
} else {
merged = append(merged, chunks[i])
}
}
chunks = merged
}
v = chunks[0]
// Final reshape
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)