Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
12a7e5ec46 gemma3: scale in attention 2025-08-19 13:43:47 -07:00
Michael Yang
b323cfe731 gemma2: use fast attention 2025-08-19 13:33:12 -07:00
2 changed files with 17 additions and 39 deletions

View File

@@ -69,10 +69,10 @@ func New(c fs.Config) (model.Model, error) {
}, },
} }
slidingWindowLen := int32(c.Uint("attention.sliding_window")) m.Cache = kvcache.NewWrapperCache(
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
m.Cache.SetConfig(ml.CacheConfig{}) kvcache.NewCausalCache(m.Shift),
)
return &m, nil return &m, nil
} }
@@ -90,12 +90,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
@@ -103,28 +97,14 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v) scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen))
k, v, mask := cache.Get(ctx) if opts.largeModelScaling {
scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads))
}
q = q.Permute(ctx, 0, 2, 1, 3) attn := nn.Attention(ctx, q, k, v, scale, cache)
k = k.Permute(ctx, 0, 2, 1, 3) attn = attn.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) return sa.Output.Forward(ctx, attn)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {

View File

@@ -86,12 +86,6 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q = sa.QueryNorm.Forward(ctx, q, opts.eps) q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps) k = sa.KeyNorm.Forward(ctx, k, opts.eps)
@@ -100,8 +94,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0 scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) if opts.largeModelScaling {
scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads))
}
kqv := nn.Attention(ctx, q, k, v, scale, cache)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv) return sa.Output.Forward(ctx, kqv)