Compare commits

...

1 Commit

Author SHA1 Message Date
Jesse Gross
1200e427f7 ollamarunner: Automatically enable flash attention
If a user hasn't explicitly enabled or disabled flash attention,
automatically enable flash attention if the model supports it and
it would not trigger a fallback to CPU.

This supports text, vision, and embedding models, as well as automatic
handling of KV cache quantization (which requires flash attention). If a
model does not call the fast fused attention operation, this is detected
and any operations that depend on it are disabled.
2025-12-17 13:09:49 -08:00
8 changed files with 249 additions and 156 deletions
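
The decision now starts from a tri-state setting rather than a boolean. As a rough guide to the hunks below, this is how the server derives that setting from OLLAMA_FLASH_ATTENTION (a minimal sketch assembled from the llm/server.go changes; the import paths are assumed from the identifiers in the diff and the helper name is illustrative):

package main

import (
	"fmt"

	"github.com/ollama/ollama/envconfig" // assumed path for envconfig.FlashAttention
	"github.com/ollama/ollama/ml"        // assumed path for the ml.FlashAttention* constants
)

func flashAttentionSetting() ml.FlashAttentionType {
	// envconfig.FlashAttention(def) returns def when OLLAMA_FLASH_ATTENTION is
	// unset, so probing with both defaults reveals whether the user set it at all.
	if envconfig.FlashAttention(true) != envconfig.FlashAttention(false) {
		return ml.FlashAttentionAuto // unset: let the runner probe the model
	}
	if envconfig.FlashAttention(false) {
		return ml.FlashAttentionEnabled // explicitly forced on
	}
	return ml.FlashAttentionDisabled // explicitly forced off
}

func main() {
	fmt.Println(flashAttentionSetting())
}

In auto mode the runner probes the model's worst-case graph before committing to flash attention; that probing is where most of the new code in this commit lives.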

View File

@@ -813,43 +813,13 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
}
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
func KVCacheTypeIsQuantized(cacheType string) bool {
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
return false
}
return true
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
if isEmbedding {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"gptoss", "gpt-oss",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))
}
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
switch cacheType {
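
Dropping the GGML receiver lets callers that only have a cache-type string (such as the runner later in this diff) reuse the check directly. A minimal usage sketch, assuming the fs/ggml import path; q8_0 stands in for any quantized cache type:

package main

import (
	"fmt"

	"github.com/ollama/ollama/fs/ggml" // assumed import path for the package above
)

func main() {
	fmt.Println(ggml.KVCacheTypeIsQuantized("f16"))  // false: "", f16, f32 and bf16 are unquantized
	fmt.Println(ggml.KVCacheTypeIsQuantized("q8_0")) // true: any other type is treated as quantized
}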

View File

@@ -696,7 +696,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() {}
func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int {
return 10

View File

@@ -188,73 +188,26 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if len(projectors) > 0 && llamaModel != nil {
loadRequest.ProjectorPath = projectors[0]
}
// Determine if the user has forced FA on or off
faUserSet := false
if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
faUserSet = true
}
fa := envconfig.FlashAttention(f.FlashAttention())
// Determine if the user has forced FA on or off
if envconfig.FlashAttention(true) != envconfig.FlashAttention(false) {
loadRequest.FlashAttention = ml.FlashAttentionAuto
} else if envconfig.FlashAttention(false) {
loadRequest.FlashAttention = ml.FlashAttentionEnabled
}
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
// that can handle it.
if fa && !ml.FlashAttentionSupported(gpus) {
// that can handle it. There are still holes in GGML's hardware detection for flash attention.
if loadRequest.FlashAttention != ml.FlashAttentionDisabled && !ml.FlashAttentionSupported(gpus) {
slog.Warn("flash attention enabled but not supported by gpu")
fa = false
}
if fa && !f.SupportsFlashAttention() {
slog.Warn("flash attention enabled but not supported by model")
fa = false
loadRequest.FlashAttention = ml.FlashAttentionDisabled
}
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
flashAttention = ml.FlashAttentionEnabled
} else {
flashAttention = ml.FlashAttentionDisabled
}
}
if kvct != "" {
if f.KVCacheTypeIsQuantized(kvct) {
if flashAttention != ml.FlashAttentionEnabled {
slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
loadRequest.KvCacheType = ""
} else if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
} else {
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
}
}
loadRequest.FlashAttention = flashAttention
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
// For Ollama engine, use our SupportsFlashAttention logic
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = ml.FlashAttentionEnabled
// Flash Attention also supports kv cache quantization
// Enable if the requested kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
}
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
gpuLibs := ml.LibraryPaths(gpus)
@@ -487,6 +440,7 @@ type LoadRequest struct {
type LoadResponse struct {
Success bool
Request LoadRequest // The original request with fields updated that the runner had to modify
Memory ml.BackendMemory
}
@@ -511,6 +465,11 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
if s.loadRequest.FlashAttention != ml.FlashAttentionEnabled && ggml.KVCacheTypeIsQuantized(s.loadRequest.KvCacheType) {
slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", s.loadRequest.KvCacheType)
s.loadRequest.KvCacheType = ""
}
// Check if embedding model and adjust batch size accordingly
_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
@@ -769,6 +728,7 @@ nextOperation:
resp.Memory.Log(slog.LevelDebug)
slog.Debug("memory", "success", resp.Success, "required", resp.Memory)
s.loadRequest = resp.Request // Incorporate any adjustments from the runner to avoid needing to do them again
pastAllocations[gpuLayers.Hash()] = struct{}{}
s.mem = &resp.Memory
@@ -822,6 +782,7 @@ nextOperation:
resp.Memory.Log(slog.LevelDebug)
slog.Debug("memory", "success", resp.Success, "required", resp.Memory)
s.loadRequest = resp.Request
if resp.Success {
verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff)
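
Note the new Request field on LoadResponse: the runner echoes the request back with any adjustments it had to make, and the server adopts that copy (s.loadRequest = resp.Request) instead of re-deriving the settings on the next operation. A self-contained toy of that round trip, with the structs trimmed to the fields that matter here (the names are stand-ins, not the llm types):

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed-down stand-ins for llm.LoadRequest / llm.LoadResponse.
type loadRequest struct {
	FlashAttention string
	KvCacheType    string
}

type loadResponse struct {
	Success bool
	Request loadRequest // the request with any runner-side adjustments applied
}

func main() {
	req := loadRequest{FlashAttention: "auto", KvCacheType: "q8_0"}

	// Runner side: the probe succeeded, so auto becomes enabled, and the
	// adjusted request is echoed back in the response.
	req.FlashAttention = "enabled"
	body, _ := json.Marshal(loadResponse{Success: true, Request: req})

	// Server side: adopt the runner's adjustments rather than re-deriving them.
	var resp loadResponse
	_ = json.Unmarshal(body, &resp)
	fmt.Printf("%+v\n", resp.Request) // {FlashAttention:enabled KvCacheType:q8_0}
}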

View File

@@ -118,7 +118,7 @@ type Context interface {
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available for
// future inference.
Reserve()
Reserve() error
MaxGraphNodes() int
Close()

View File

@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
}
func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention == ml.FlashAttentionEnabled {
if b.flashAttention != ml.FlashAttentionDisabled {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -842,11 +842,16 @@ func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
}
}
func (c *Context) Reserve() {
func (c *Context) Reserve() error {
if c.batchSize > 0 {
C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
}
flashBackendAssignments, err := validateGraph(c.graph, c.b.flashAttention)
if err != nil {
return err
}
reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
@@ -867,6 +872,142 @@ func (c *Context) Reserve() {
if !reserved {
panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
}
// If flash attention is in auto mode, ensure that the scheduler placed each flash attention operation on the same
// (or higher priority) backend as the one its weights were originally loaded on. If it landed on a lower priority
// backend (e.g. CPU), that likely means the backend does not support flash attention for this graph.
if c.b.flashAttention == ml.FlashAttentionAuto {
flashIdx := 0
for i := range C.ggml_graph_n_nodes(c.graph) {
node := C.ggml_graph_node(c.graph, i)
if node.op != C.GGML_OP_FLASH_ATTN_EXT {
continue
}
if flashIdx >= len(flashBackendAssignments) {
slog.Debug("flash attention assignment missing",
"index", flashIdx,
"tensor", C.GoString(C.ggml_get_name(node)))
return errors.New("flash attention not supported by backend")
}
assignedBT := flashBackendAssignments[flashIdx]
flashIdx++
if node.buffer == nil || assignedBT == nil {
continue
}
bufferType := C.ggml_backend_buffer_get_type(node.buffer)
actualPriority := bufferTypePriority(bufferType, c.b.schedBufts)
expectedPriority := bufferTypePriority(assignedBT, c.b.schedBufts)
// A lower numbered priority is better here
if actualPriority > expectedPriority {
slog.Debug("flash attention not supported by backend",
"tensor", C.GoString(C.ggml_get_name(node)),
"assigned_buffer_type", C.GoString(C.ggml_backend_buft_name(bufferType)),
"assigned_priority", actualPriority,
"expected_buffer_type", C.GoString(C.ggml_backend_buft_name(assignedBT)),
"expected_priority", expectedPriority)
return errors.New("flash attention not supported by backend")
}
}
}
return nil
}
func bufferTypePriority(buft C.ggml_backend_buffer_type_t, schedBufts []C.ggml_backend_buffer_type_t) int {
for i, b := range schedBufts {
if b == buft {
return i
}
}
return len(schedBufts)
}
// Check that there are no illegal operations and record the expected backend placement of each flash attention
// operation before the scheduler runs so that it can be compared with the scheduler's actual result afterwards.
func validateGraph(graph *C.struct_ggml_cgraph, flashAttention ml.FlashAttentionType) ([]C.ggml_backend_buffer_type_t, error) {
var assignments []C.ggml_backend_buffer_type_t
for i := range C.ggml_graph_n_nodes(graph) {
node := C.ggml_graph_node(graph, i)
switch node.op {
// Only flash attention supports quantized KV cache, so if we have a matmul that uses a quantized input (other than weights),
// it means that the model is using its own implementation of attention.
case C.GGML_OP_MUL_MAT:
for srcIndex := range int(C.GGML_MAX_SRC) {
src := node.src[srcIndex]
if src == nil {
continue
}
var quantized *C.struct_ggml_tensor
for current := src; current != nil; current = current.view_src {
if C.ggml_is_quantized(current._type) {
quantized = current
break
}
}
// If matmul has a quantized input, it is only supported if it is weights (due to uniform stride)
if quantized != nil &&
!(quantized.buffer != nil && C.ggml_backend_buffer_get_usage(quantized.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
slog.Debug("unsupported quantized matmul input",
"tensor", C.GoString(C.ggml_get_name(node)),
"src", C.GoString(C.ggml_get_name(src)),
"type", C.GoString(C.ggml_type_name(quantized._type)))
return nil, errors.New("unsupported quantized matmul input")
}
}
// Build a mapping of flash attention operations to their most direct weight input. We do this before the scheduler runs
// because the graph is fully connected. After scheduling, the graph is hard to trace because it is broken up into splits.
// We index by flash attention number (more or less equivalent to layer) since that is persistent across scheduling.
case C.GGML_OP_FLASH_ATTN_EXT:
if flashAttention == ml.FlashAttentionAuto {
// Breadth-first search for the first ancestor that has a buffer with weights
queue := []*C.struct_ggml_tensor{node}
visited := make(map[*C.struct_ggml_tensor]struct{})
var ancestor *C.struct_ggml_tensor
for len(queue) > 0 {
current := queue[0]
queue = queue[1:]
if _, ok := visited[current]; ok {
continue
}
visited[current] = struct{}{}
// Only use weights as reference points - we don't want to use inputs like the cache mask, which are always on the CPU
if current.buffer != nil && C.ggml_backend_buffer_get_usage(current.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS {
ancestor = current
break
}
for srcIndex := range int(C.GGML_MAX_SRC) {
if src := current.src[srcIndex]; src != nil {
queue = append(queue, src)
}
}
}
var bufferType C.ggml_backend_buffer_type_t
if ancestor != nil {
bufferType = C.ggml_backend_buffer_get_type(ancestor.buffer)
}
assignments = append(assignments, bufferType)
}
}
}
return assignments, nil
}
func (c *Context) MaxGraphNodes() int {
@@ -1679,7 +1820,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention == ml.FlashAttentionEnabled {
if t.b.flashAttention != ml.FlashAttentionDisabled {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
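
The priority comparison in Reserve reads best-first: schedBufts is ordered from most to least preferred, so a larger index is a worse placement, and a FLASH_ATTN_EXT node that lands on a later buffer type than its weights means the preferred backend rejected the fused op. A self-contained toy of the same check, with made-up buffer-type names standing in for the C handles:

package main

import "fmt"

// priority mirrors bufferTypePriority above: the index in the best-first list,
// with unknown types sorting last.
func priority(buft string, schedBufts []string) int {
	for i, b := range schedBufts {
		if b == buft {
			return i
		}
	}
	return len(schedBufts)
}

func main() {
	schedBufts := []string{"CUDA0", "CPU"} // best-first, as the scheduler orders them

	expected := priority("CUDA0", schedBufts) // where the attention weights were loaded
	actual := priority("CPU", schedBufts)     // where the scheduler placed FLASH_ATTN_EXT

	if actual > expected { // a lower number is better
		fmt.Println("flash attention not supported by backend")
	}
}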

View File

@@ -935,13 +935,13 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
case llm.LoadOperationClose:
// No-op for us
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
return
}
resp := llm.LoadResponse{Success: true}
resp := llm.LoadResponse{Success: true, Request: req}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
return

View File

@@ -98,7 +98,10 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
}
}
} else {
computeCtx.Reserve()
err := computeCtx.Reserve()
if err != nil {
return nil, err
}
}
}

View File

@@ -1160,22 +1160,12 @@ func (s *Server) reserveWorstCaseGraph(prompt bool) error {
}
ctx.SetBatchSize(batchSize)
ctx.Forward(t).Reserve()
return nil
return ctx.Forward(t).Reserve()
}
// allocModel pre-allocates the maximum needed memory for a model
// based on the given parameters
func (s *Server) allocModel(
mpath string,
params ml.BackendParams,
loraPath []string,
parallel int,
kvCacheType string,
kvSize int,
multiUserCache bool,
) (panicErr error) {
func (s *Server) allocModel(mpath string, req *llm.LoadRequest) (panicErr error) {
// Convert memory allocation panics to errors
defer func() {
if r := recover(); r != nil {
@@ -1192,43 +1182,73 @@ func (s *Server) allocModel(
}
}()
var err error
s.model, err = model.New(mpath, params)
if err != nil {
return err
}
// TODO(jessegross): LoRA loading
if len(loraPath) > 0 {
return errors.New("loras are not yet implemented")
}
if s.model.Config().Cache == nil {
if parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
reload:
for range 2 {
params := ml.BackendParams{
AllocMemory: req.Operation != llm.LoadOperationFit,
NumThreads: req.NumThreads,
GPULayers: req.GPULayers,
FlashAttention: req.FlashAttention,
}
if s.batchSize < kvSize {
s.batchSize = kvSize
slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
var err error
s.model, err = model.New(mpath, params)
if err != nil {
return err
}
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
return err
}
// TODO(jessegross): LoRA loading
if len(req.LoraPath) > 0 {
return errors.New("loras are not yet implemented")
}
s.parallel = parallel
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
if params.FlashAttention == ml.FlashAttentionDisabled && ggml.KVCacheTypeIsQuantized(req.KvCacheType) {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", req.KvCacheType)
req.KvCacheType = ""
}
if s.model.Config().Cache == nil {
if req.Parallel > 1 {
req.Parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
if req.BatchSize < req.KvSize {
req.BatchSize = req.KvSize
slog.Warn("model does not support caching, setting batch size to context length", "batch_size", req.KvSize)
}
}
s.cache, err = NewInputCache(s.model, req.KvCacheType, int32(req.KvSize), req.Parallel, req.BatchSize, req.MultiUserCache)
if err != nil {
return err
}
s.batchSize = req.BatchSize
s.parallel = req.Parallel
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
for _, prompt := range []bool{true, false} {
if err := s.reserveWorstCaseGraph(prompt); err != nil {
if req.FlashAttention != ml.FlashAttentionDisabled {
slog.Warn("flash attention enabled but not supported by model")
req.FlashAttention = ml.FlashAttentionDisabled
s.closeModel()
continue reload
}
return err
}
}
if req.FlashAttention == ml.FlashAttentionAuto {
req.FlashAttention = ml.FlashAttentionEnabled
}
err = s.reserveWorstCaseGraph(true)
if err != nil {
return nil
}
return s.reserveWorstCaseGraph(false)
return errors.New("unable to allocate model")
}
// closeModel frees all memory associated with a model
@@ -1243,7 +1263,15 @@ func (s *Server) closeModel() {
// loadModel loads the weights for a model. The memory must already
// have been allocated with allocModel
func (s *Server) loadModel() {
func (s *Server) loadModel(req llm.LoadRequest) {
if req.FlashAttention != ml.FlashAttentionDisabled {
slog.Info("enabling flash attention")
}
if ggml.KVCacheTypeIsQuantized(req.KvCacheType) {
slog.Info("enabling kv cache quantization", "type", req.KvCacheType)
}
err := s.model.Backend().Load(context.TODO(),
func(progress float32) {
s.progress = progress
@@ -1279,7 +1307,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
if req.Operation == llm.LoadOperationClose {
s.closeModel()
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
return
@@ -1288,27 +1316,16 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
s.lastLoad.Operation = req.Operation
loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)
s.lastLoad = req
if loadModel {
s.closeModel()
params := ml.BackendParams{
AllocMemory: req.Operation != llm.LoadOperationFit,
NumThreads: req.NumThreads,
GPULayers: req.GPULayers,
FlashAttention: req.FlashAttention,
}
s.batchSize = req.BatchSize
err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
err := s.allocModel(s.modelPath, &req)
if err != nil {
s.closeModel()
var noMem ml.ErrNoMem
if errors.As(err, &noMem) {
resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
resp := llm.LoadResponse{Success: false, Request: req, Memory: noMem.BackendMemory}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
@@ -1321,6 +1338,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
}
}
s.lastLoad = req
mem := s.model.Backend().BackendMemory()
switch req.Operation {
@@ -1332,10 +1350,10 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
case llm.LoadOperationCommit:
s.status = llm.ServerStatusLoadingModel
go s.loadModel()
go s.loadModel(req)
}
resp := llm.LoadResponse{Success: true, Memory: mem}
resp := llm.LoadResponse{Success: true, Request: req, Memory: mem}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
return
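
Because the removed and added lines are interleaved above, the new allocModel flow is easier to see in isolation: build the model and cache, reserve the worst-case graphs, and if reservation fails while flash attention is not disabled, rebuild once with it disabled; if reservation succeeds in auto mode, report flash attention as enabled. A self-contained toy of that retry loop, with stubs standing in for model construction and graph reservation (the names here are illustrative, not from the commit):

package main

import (
	"errors"
	"fmt"
)

type faMode int

const (
	faAuto faMode = iota
	faEnabled
	faDisabled
)

// reserve stands in for reserveWorstCaseGraph: pretend the backend rejects the
// fused attention op unless flash attention is disabled.
func reserve(fa faMode, backendSupportsFA bool) error {
	if fa != faDisabled && !backendSupportsFA {
		return errors.New("flash attention not supported by backend")
	}
	return nil
}

// alloc mirrors the shape of the new allocModel retry loop.
func alloc(fa faMode, backendSupportsFA bool) (faMode, error) {
	for range 2 { // at most one retry: probe with flash attention, then without
		if err := reserve(fa, backendSupportsFA); err != nil {
			if fa != faDisabled {
				fa = faDisabled // rebuild the model with flash attention off
				continue
			}
			return fa, err
		}
		if fa == faAuto {
			fa = faEnabled // the probe succeeded, so commit to flash attention
		}
		return fa, nil
	}
	return fa, errors.New("unable to allocate model") // safety net, as above
}

func main() {
	fmt.Println(alloc(faAuto, true))  // 1 <nil> (faEnabled)
	fmt.Println(alloc(faAuto, false)) // 2 <nil> (faDisabled)
}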