Compare commits

...

1 Commit

Author SHA1 Message Date
jmorganca
794ea02bc8 qwen3next: fix issue in delta net
g_diff_exp was being broadcast across the wrong axis when multiplying with k. This fix reshapes g_diff_exp to [1, chunkSize, nChunks, numVHeads*nSeqs] before the multiplication.
2026-02-04 08:55:35 -08:00

View File

@@ -406,8 +406,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
gDiffExp := gDiff.Exp(ctx)
// key_gdiff = k * exp(g_diff)
keyGDiff := k.Mul(ctx, gDiffExp)
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
@@ -444,12 +446,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
// Update state for next chunk using pre-computed values
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1)
// kgdmulvnew = key_gdiff^T @ v_new
kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// state = state * g_last + kgdmulvnew