@@ -122,7 +122,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
     b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
     a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
 
-    // Keep beta layout consistent with qwen35 and llama.cpp:
+    // Keep beta layout consistent with qwen35.
     // [1, numVHeads, nSeqTokens, nSeqs]
     beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
     alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
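Reviewer note, not part of the patch: the b/a split above carves one fused [beta|alpha] projection into per-head views before the Contiguous calls fix their layouts. Below is a minimal plain-Go sketch of that split; the sizes and values are hypothetical, and it uses ordinary Go slices rather than the ml.Tensor Slice API.

package main

import "fmt"

func main() {
	const betaSize, alphaSize = 4, 4 // hypothetical per-token lane counts

	// One token's fused projection row: beta lanes first, then alpha lanes,
	// mirroring the innermost dimension of mixedBAReshaped.
	fused := []float32{0.9, 0.8, 0.7, 0.6, -1.2, -1.1, -1.0, -0.9}

	beta := fused[0:betaSize]                     // like Slice(ctx, 0, 0, betaSize, 1)
	alpha := fused[betaSize : betaSize+alphaSize] // like Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)

	fmt.Println("beta:", beta)   // beta: [0.9 0.8 0.7 0.6]
	fmt.Println("alpha:", alpha) // alpha: [-1.2 -1.1 -1 -0.9]
}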
@@ -333,7 +333,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
     q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
     k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
     v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
-    // Match llama.cpp delta-net-base layout:
     // gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
     gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
     beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
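Reviewer note, not part of the patch: Permute(ctx, 0, 2, 1, 3) swaps the middle two axes, moving heads past tokens so that each head's tokens end up contiguous after the Contiguous call. A minimal sketch of that reindexing on a flat row-major buffer follows; the sizes are hypothetical, it is plain Go rather than the ml tensor API, and it ignores the ml package's internal dim-ordering conventions.

package main

import "fmt"

// swapAxes12 copies a flat [d0][d1][d2][d3] buffer into [d0][d2][d1][d3] order,
// the same axis swap that Permute(ctx, 0, 2, 1, 3) followed by Contiguous performs.
func swapAxes12(src []int, d0, d1, d2, d3 int) []int {
	dst := make([]int, len(src))
	for i0 := 0; i0 < d0; i0++ {
		for i1 := 0; i1 < d1; i1++ {
			for i2 := 0; i2 < d2; i2++ {
				for i3 := 0; i3 < d3; i3++ {
					s := ((i0*d1+i1)*d2+i2)*d3 + i3 // index in [d0, d1, d2, d3]
					d := ((i0*d2+i2)*d1+i1)*d3 + i3 // index in [d0, d2, d1, d3]
					dst[d] = src[s]
				}
			}
		}
	}
	return dst
}

func main() {
	// 1 x 2 heads x 3 tokens x 1 sequence, numbered for readability:
	// head 0 holds tokens 0..2, head 1 holds tokens 3..5.
	src := []int{0, 1, 2, 3, 4, 5}
	fmt.Println(swapAxes12(src, 1, 2, 3, 1)) // [0 3 1 4 2 5]: token-major, heads interleaved
}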
@@ -437,60 +436,64 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
     keyGDiff := k.Mul(ctx, gDiffExpReshaped)
     keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
-    // Process chunks and update state
-    var coreAttnOut ml.Tensor
-    newState := state
+    // Process chunks and update state.
+    // Keep a transposed view of v and recurrent state across chunks so the
+    // chunk loop does not need extra transpose+contiguous nodes.
+    vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
+    stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
 
     for chunk := range nChunks {
         qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
-        vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
+        vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
         gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
         kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
         attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
 
-        // state^T - permute is needed but Contiguous creates a copy
-        stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
-
-        // v_prime = k_cumdecay @ state
-        vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
+        // v'_t = k_cumdecay @ state_t
+        vTPrime := kCumdecayChunk.Mulmat(ctx, stateT)
 
-        // v_new = v - v_prime
-        vNew := vChunk.Sub(ctx, vPrime)
-        vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+        // v_t_new = v_t - v'_t
+        vTNewChunk := vTChunk.Sub(ctx, vTPrime)
 
         // attn_inter = (q * g_exp) @ state
         qGExp := qChunk.Mul(ctx, gExpChunk)
        attnInter := stateT.Mulmat(ctx, qGExp)
 
         // core_attn_out = attn_inter + attn @ v_new
-        vAttn := vNewT.Mulmat(ctx, attnChunk)
+        vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
         coreAttnOutChunk := attnInter.Add(ctx, vAttn)
 
-        if coreAttnOut == nil {
-            coreAttnOut = coreAttnOutChunk
-        } else {
-            coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
-        }
+        v = v.SetInplace(
+            ctx,
+            coreAttnOutChunk,
+            v.Stride(1),
+            v.Stride(2),
+            v.Stride(3),
+            chunk*v.Stride(2),
+        )
 
         // Update state for next chunk
         gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
         kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
-        kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
+        // kgdmulvnew = key_gdiff_t @ v_new_t
+        kgdMulVNew := kGDiffChunkT.Mulmat(ctx, vTNewChunk)
 
-        // state = state * g_last + kgdmulvnew
-        gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
-        newState = newState.Mul(ctx, gExpLastReshaped)
-        newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
+        // stateT = stateT * g_last + kgdmulvnew
+        stateT = stateT.Mul(ctx, gExpLastChunk)
+        stateT = stateT.Add(ctx, kgdMulVNew)
     }
 
     // Final reshape
-    coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
+    coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
 
     // Slice to remove padding
     if pad > 0 {
         coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
     }
 
+    // Convert stateT back to cache layout [S_v, S_v, H_v, nSeqs]
+    newState := stateT.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, numVHeads, nSeqs)
+
     // Update delta state in cache
     cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
 
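Reviewer note, not part of the patch: the chunk loop implements exactly the recurrence spelled out in its comments. Below is a minimal scalar sketch of that recurrence with one head and a 1x1 state, so every matrix product degenerates to a multiply; all numbers are made up, and each field stands in for a tensor chunk in the real kernel (qChunk, vTChunk, gExpChunk, kCumdecayChunk, kGDiffChunkT, attnChunk). It shows the order of operations, not the real layouts or strides.

package main

import "fmt"

func main() {
	// Per-chunk inputs, all hypothetical values.
	chunks := []struct {
		q, v, gExp, gLast, kCumdecay, kGDiff, attn float64
	}{
		{q: 0.5, v: 1.0, gExp: 0.9, gLast: 0.8, kCumdecay: 0.3, kGDiff: 0.2, attn: 0.7},
		{q: 0.4, v: -0.5, gExp: 0.95, gLast: 0.85, kCumdecay: 0.25, kGDiff: 0.15, attn: 0.6},
	}

	state := 0.0 // recurrent delta state, zero-initialized like a fresh cache slot
	for _, c := range chunks {
		vPrime := c.kCumdecay * state         // v'_t = k_cumdecay @ state_t
		vNew := c.v - vPrime                  // v_t_new = v_t - v'_t
		attnInter := c.q * c.gExp * state     // attn_inter = (q * g_exp) @ state
		out := attnInter + c.attn*vNew        // core_attn_out = attn_inter + attn @ v_new
		state = state*c.gLast + c.kGDiff*vNew // state = state * g_last + kgdmulvnew
		fmt.Printf("out=%.4f state=%.4f\n", out, state)
	}
}

In the patched code the per-chunk output is written straight into v at offset chunk*v.Stride(2) via SetInplace, so the final coreAttnOut is just a relabeling of v's buffer rather than a Concat of per-chunk tensors, and the loop-carried state lives in transposed form (stateT) until the single conversion back to cache layout after the loop.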