Compare commits

...

4 Commits

Author SHA1 Message Date
Michael Yang
89637ae43b gemma2: enable flash attention 2025-12-16 09:45:05 -08:00
Michael Yang
f49797fbdb generic pad 2025-12-16 09:45:05 -08:00
Michael Yang
baae175ebe cleanup 2025-12-16 09:45:05 -08:00
Michael Yang
de82b1f9a3 cleanup attention interface
the updated interface supports variadic attention options, which removes
the need for the individual `AttentionWith...` functions. it also means
more models can use the attention interface, e.g. models with
custom masks, logit softcapping, etc.

additionally, this interface should be less error-prone since there are
now reasonable defaults for all optional parameters
2025-12-16 09:45:04 -08:00
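Not part of the diff: a minimal sketch of the call-site change this commit message describes, based on the signatures visible in the hunks below. The helper name forwardAttention and its arguments (query, key, value, sinks, scale, cache) are placeholders for whatever a model's Forward method already has in hand.

package example

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/attention"
)

// forwardAttention shows the shape of a call site after the cleanup.
// Previously each optional parameter needed a dedicated helper, e.g.
// nn.AttentionWithSinks(ctx, query, key, value, sinks, scale, cache);
// now optional parameters are variadic functional options, and Scale
// defaults to 1/sqrt(d_k) when WithScale is omitted.
func forwardAttention(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	return nn.Attention(ctx, query, key, value, cache,
		attention.WithSinks(sinks),
		attention.WithScale(scale),
	)
}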
30 changed files with 228 additions and 277 deletions

View File

@@ -826,10 +826,6 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()

View File

@@ -33,7 +33,7 @@ type Backend interface {
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
// conjunction with [nn.fastAttention].
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
@@ -152,7 +152,6 @@ type Tensor interface {
Div(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
MulmatID(ctx Context, t2, ids Tensor) Tensor
AddID(ctx Context, t2, ids Tensor) Tensor
@@ -213,32 +212,6 @@ type Tensor interface {
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//
// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
}
type number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |

View File

@@ -19,6 +19,7 @@ import (
"io"
"log/slog"
"maps"
"math"
"os"
"runtime"
"slices"
@@ -35,6 +36,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"golang.org/x/sync/errgroup"
)
@@ -882,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}
func pad(length, pad C.size_t) C.size_t {
func pad[T C.size_t | int](length, pad T) T {
return ((length + pad - 1) / pad) * pad
}
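Not part of the diff: the "generic pad" commit makes this round-up helper usable with plain ints as well as C.size_t (the new SDPA path below calls it for mask padding). A standalone sketch of the same arithmetic, with the constraint narrowed to ordinary integer types since C.size_t only exists inside the cgo package:

package example

import "fmt"

// pad mirrors the generic helper above: round length up to the nearest
// multiple of pad. The real constraint is C.size_t | int; plain integer
// types are used here so the sketch compiles outside cgo.
func pad[T ~int | ~int64](length, pad T) T {
	return ((length + pad - 1) / pad) * pad
}

func main() {
	fmt.Println(pad(97, 32))  // 128: rounded up to the next multiple of 32
	fmt.Println(pad(256, 32)) // 256: already a multiple, unchanged
}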
@@ -1248,16 +1250,6 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
b: t.b,
t: mul,
}
}
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1648,75 +1640,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
// If the cache didn't help us with required transformations, do them here
if !cacheConfigApplied {
cacheConfig := t.b.CacheConfig()
// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
if cacheConfig.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
if mask != nil {
padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
if padSize > 0 {
mask = mask.Pad(ctx, 0, padSize, 0, 0)
}
if mask.DType() != cacheConfig.MaskDType {
mask = mask.Cast(ctx, cacheConfig.MaskDType)
}
}
}
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
if sinks != nil {
C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
}
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
if vmla != nil {
var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = vmla.Mulmat(ctx, cur)
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = cur.Contiguous(ctx)
kqv = cur.(*Tensor).t
}
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
if sinks != nil {
C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
}
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1849,3 +1772,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te
}
return s
}
func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor {
opts := attention.Options{
Scale: 1 / math.Sqrt(float64(t.Dim(0))),
}
for _, fn := range fns {
fn(&opts)
}
if !opts.Cached {
config := t.b.CacheConfig()
if config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
if opts.Mask != nil {
if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 {
opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
}
if opts.Mask.DType() != config.MaskDType {
opts.Mask = opts.Mask.Cast(ctx, config.MaskDType)
}
}
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
var mask *C.struct_ggml_tensor
if opts.Mask != nil {
mask = opts.Mask.(*Tensor).t
}
if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3)
tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap))
C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32)
if opts.Sinks != nil {
C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t)
}
var attention ml.Tensor = &Tensor{b: t.b, t: tt}
if opts.MLA != nil {
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = opts.MLA.Mulmat(ctx, attention)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
return attention
}
scores := key.Mulmat(ctx, query)
C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32)
if opts.LogitSoftcap > 0 {
scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap))
}
if opts.Cached {
scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)}
} else {
scores = scores.Scale(ctx, opts.Scale)
if opts.Mask != nil {
scores = scores.Add(ctx, opts.Mask)
}
scores = scores.Softmax(ctx)
}
if opts.Sinks != nil {
C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t)
}
if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
attention := value.Mulmat(ctx, scores)
if opts.MLA != nil {
attention = opts.MLA.Mulmat(ctx, attention)
}
return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
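Not part of the diff: in the non-flash branch above, logit softcapping is applied as Scale(1/c), Tanh, Scale(c), which is equivalent to c·tanh(x/c). A standalone illustration of that chain on a single score; the cap value 50 is only an example:

package example

import (
	"fmt"
	"math"
)

// softcap reproduces the Scale(1/c) -> Tanh -> Scale(c) chain from the
// non-flash path of SDPA on one scalar score: scores stay nearly linear
// for |x| much smaller than c and are bounded by ±c.
func softcap(x, c float64) float64 {
	return c * math.Tanh(x/c)
}

func main() {
	const c = 50.0 // example cap value
	fmt.Printf("%.2f\n", softcap(5, c))   // ≈ 4.98, nearly unchanged
	fmt.Printf("%.2f\n", softcap(500, c)) // ≈ 50.00, clamped near the cap
}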

View File

@@ -1,12 +1,17 @@
package nn
import (
"fmt"
"log"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/attention"
)
type fastAttention interface {
SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor
}
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
@@ -21,27 +26,19 @@ import (
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))
}
ctx.Forward(key, value)
@@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
key, value, mask = cache.Get(ctx)
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
cacheConfigApplied := cache != nil
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
if t, ok := query.(fastAttention); ok {
return t.SDPA(ctx, key, value, append([]func(*attention.Options){
attention.WithMask(mask),
func(opts *attention.Options) { opts.Cached = cache != nil },
}, fns...)...)
}
panic("Attention not implemented for this tensor type")
}

View File

@@ -0,0 +1,55 @@
package attention
import (
"github.com/ollama/ollama/ml"
)
type Options struct {
// Scale is a scaling factor applied to the attention scores. Default is 1/√d_k.
Scale float64
// LogitSoftcap is used to apply a soft cap to the logits before softmax.
LogitSoftcap float32
// Mask is used in some attention mechanisms to mask out certain positions.
Mask ml.Tensor
// Sinks optionally holds attention sink logits that are added to the softmax in some attention mechanisms.
Sinks ml.Tensor
// MLA is an optional projection applied to the attention output, used by multi-head latent attention (MLA) models.
MLA ml.Tensor
// Cached indicates whether key/value were retrieved from cache.
Cached bool
}
func WithScale(scale float64) func(*Options) {
return func(o *Options) {
o.Scale = scale
}
}
func WithSinks(sinks ml.Tensor) func(*Options) {
return func(o *Options) {
o.Sinks = sinks
}
}
func WithMLA(mla ml.Tensor) func(*Options) {
return func(o *Options) {
o.MLA = mla
}
}
func WithMask(mask ml.Tensor) func(*Options) {
return func(o *Options) {
o.Mask = mask
}
}
func WithLogitSoftcap(softcap float32) func(*Options) {
return func(o *Options) {
o.LogitSoftcap = softcap
}
}
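Not part of the diff: the constructors above follow the functional-options pattern. (*Tensor).SDPA starts from defaults (Scale = 1/√d_k, all other fields zero) and applies the options in the order given, and nn.Attention prepends WithMask plus the Cached flag before the caller's options, so caller-supplied options overwrite whichever fields they target. A reduced stand-in showing that composition (the options struct and with* helpers here are illustrative, not the real attention package):

package example

import "fmt"

// Reduced stand-in for attention.Options: each option is a func that
// mutates the struct, applied in the order it was passed.
type options struct {
	scale        float64
	logitSoftcap float32
}

func withScale(s float64) func(*options) {
	return func(o *options) { o.scale = s }
}

func withLogitSoftcap(c float32) func(*options) {
	return func(o *options) { o.logitSoftcap = c }
}

// apply mirrors the loop at the top of (*Tensor).SDPA: start from the
// defaults, then let every option overwrite the field it targets.
func apply(defaults options, fns ...func(*options)) options {
	for _, fn := range fns {
		fn(&defaults)
	}
	return defaults
}

func main() {
	defaults := options{scale: 0.125} // stands in for 1/sqrt(d_k)
	fmt.Println(apply(defaults))                                     // {0.125 0}: defaults untouched
	fmt.Println(apply(defaults, withScale(1), withLogitSoftcap(50))) // {1 50}: options override the fields they target
}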

View File

@@ -2,7 +2,6 @@ package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -66,22 +67,22 @@ type Attention struct {
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
query = m.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
query = m.QA.Forward(ctx, hiddenStates)
query = m.QANorm.Forward(ctx, query, opts.eps)
query = m.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
compressedKV := m.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
@@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
var attention ml.Tensor
kPass = m.KVANorm.Forward(ctx, kPass, opts.eps)
if !opts.isMLA { // v3
kPass = attn.KVB.Forward(ctx, kPass)
kPass = m.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
@@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale))
} else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass)
qPassAbsorb := m.KB.Forward(ctx, qPass)
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
@@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
key := kRot.Concat(ctx, kPass, 0)
value := kPass
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
hiddenStates = nn.Attention(ctx, query, key, value, cache,
attention.WithMLA(m.VB.Weight),
attention.WithScale(opts.kqScale),
)
}
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength)
return m.Output.Forward(ctx, hiddenStates)
}
type MLP interface {

View File

@@ -1,11 +1,11 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
)
type samModel struct {
@@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, hiddenStates)
}
type samMLP struct {

View File

@@ -1,8 +1,6 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}

View File

@@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -72,9 +73,10 @@ func New(c fs.Config) (model.Model, error) {
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
m.Cache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
@@ -106,28 +108,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
hiddenState = nn.Attention(ctx, q, k, v, cache,
attention.WithLogitSoftcap(opts.attnLogitSoftcap),
attention.WithScale(1),
)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, hiddenState)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {

View File

@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1))
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)

View File

@@ -1,8 +1,6 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
value = value.RMSNorm(ctx, nil, opts.eps)
}
attention := nn.Attention(ctx, query, key, value, 1., cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
return attn.Output.Forward(ctx, hiddenStates)
}
type TextMLP struct {

View File

@@ -2,13 +2,13 @@ package gptoss
import (
"cmp"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention).Add(ctx, residual)
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual)
}
type MLPBlock struct {

View File

@@ -2,7 +2,6 @@ package llama
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
query = query.Mul(ctx, attentionScales)
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}

View File

@@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
q = q.Mul(ctx, positionsScale)
}
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv := nn.Attention(ctx, q, k, v, cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}

View File

@@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@@ -1,7 +1,6 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
@@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
@@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
}
key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)

View File

@@ -1,7 +1,6 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
@@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@@ -2,7 +2,6 @@ package nomicbert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

View File

@@ -2,7 +2,6 @@ package olmo3
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)

View File

@@ -3,7 +3,6 @@ package qwen2
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
@@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return attn.Output.Forward(ctx, attention)

View File

@@ -1,8 +1,6 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv := nn.Attention(ctx, q, k, v, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/ml/nn/attention"
)
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
@@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2))
return sa.Output.Forward(ctx, hiddenStates)
}
// VisionMLP implements the multi-layer perceptron

View File

@@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten
value := sa.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
}