mirror of https://github.com/ollama/ollama.git, synced 2026-01-02 12:38:15 -05:00

Compare commits

1 commit on branch implement-…

| Author | SHA1 | Date |
|---|---|---|
| jessegross | 1200e427f7 | |
```diff
@@ -813,43 +813,13 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
 }
 
 // KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
-func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
+func KVCacheTypeIsQuantized(cacheType string) bool {
 	if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
 		return false
 	}
 	return true
 }
 
-// SupportsFlashAttention checks if the model supports flash attention
-func (f GGML) SupportsFlashAttention() bool {
-	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
-	if isEmbedding {
-		return false
-	}
-
-	if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
-		return false
-	}
-
-	// Check head counts match and are non-zero
-	headCountK := f.KV().EmbeddingHeadCountK()
-	headCountV := f.KV().EmbeddingHeadCountV()
-	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
-}
-
-// FlashAttention checks if the model should enable flash attention
-func (f GGML) FlashAttention() bool {
-	return slices.Contains([]string{
-		"bert",
-		"gemma3",
-		"gptoss", "gpt-oss",
-		"mistral3",
-		"olmo3",
-		"qwen3", "qwen3moe",
-		"qwen3vl", "qwen3vlmoe",
-	}, f.KV().String("general.architecture"))
-}
-
 // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
 func kvCacheBytesPerElement(cacheType string) float64 {
 	switch cacheType {
```
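As a quick check of the now package-level predicate, it can be exercised on its own; a minimal sketch (the predicate body is copied from the hunk above, the surrounding program is added only for illustration):

```go
package main

import "fmt"

// kvCacheTypeIsQuantized mirrors the predicate above: anything other than the
// unquantized float types ("", "f16", "f32", "bf16") counts as a quantized
// KV cache type.
func kvCacheTypeIsQuantized(cacheType string) bool {
	if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
		return false
	}
	return true
}

func main() {
	for _, ct := range []string{"", "f16", "bf16", "q8_0", "q4_0"} {
		fmt.Printf("%q quantized=%v\n", ct, kvCacheTypeIsQuantized(ct))
	}
}
```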
```diff
@@ -696,7 +696,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
 
 func (c *testContext) Compute(...ml.Tensor) {}
 
-func (c *testContext) Reserve() {}
+func (c *testContext) Reserve() error { return nil }
 
 func (c *testContext) MaxGraphNodes() int {
 	return 10
```
```diff
@@ -188,73 +188,26 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	if len(projectors) > 0 && llamaModel != nil {
 		loadRequest.ProjectorPath = projectors[0]
 	}
-	// Determine if the user has forced FA on or off
-	faUserSet := false
-	if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
-		faUserSet = true
-	}
-
-	fa := envconfig.FlashAttention(f.FlashAttention())
+	// Determine if the user has forced FA on or off
+	if envconfig.FlashAttention(true) != envconfig.FlashAttention(false) {
+		loadRequest.FlashAttention = ml.FlashAttentionAuto
+	} else if envconfig.FlashAttention(false) {
+		loadRequest.FlashAttention = ml.FlashAttentionEnabled
+	}
 
 	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
-	// that can handle it.
-	if fa && !ml.FlashAttentionSupported(gpus) {
+	// that can handle it. There are still holes in GGML's hardware detection for flash attention.
+	if loadRequest.FlashAttention != ml.FlashAttentionDisabled && !ml.FlashAttentionSupported(gpus) {
 		slog.Warn("flash attention enabled but not supported by gpu")
-		fa = false
-	}
-
-	if fa && !f.SupportsFlashAttention() {
-		slog.Warn("flash attention enabled but not supported by model")
-		fa = false
+		loadRequest.FlashAttention = ml.FlashAttentionDisabled
 	}
 
 	kvct := strings.ToLower(envconfig.KvCacheType())
 
-	if textProcessor == nil {
-		flashAttention := ml.FlashAttentionAuto
-		if faUserSet {
-			if fa {
-				flashAttention = ml.FlashAttentionEnabled
-			} else {
-				flashAttention = ml.FlashAttentionDisabled
-			}
-		}
-
-		if kvct != "" {
-			if f.KVCacheTypeIsQuantized(kvct) {
-				if flashAttention != ml.FlashAttentionEnabled {
-					slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
-					loadRequest.KvCacheType = ""
-				} else if f.SupportsKVCacheType(kvct) {
-					loadRequest.KvCacheType = kvct
-				} else {
-					slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
-				}
-			} else {
-				if f.SupportsKVCacheType(kvct) {
-					loadRequest.KvCacheType = kvct
-				} else {
-					slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
-				}
-			}
-		}
-		loadRequest.FlashAttention = flashAttention
+	if f.SupportsKVCacheType(kvct) {
+		loadRequest.KvCacheType = kvct
 	} else {
-		// For Ollama engine, use our SupportsFlashAttention logic
-		if fa {
-			slog.Info("enabling flash attention")
-			loadRequest.FlashAttention = ml.FlashAttentionEnabled
-
-			// Flash Attention also supports kv cache quantization
-			// Enable if the requested and kv cache type is supported by the model
-			if f.SupportsKVCacheType(kvct) {
-				loadRequest.KvCacheType = kvct
-			} else {
-				slog.Warn("kv cache type not supported by model", "type", kvct)
-			}
-		} else if kvct != "" && kvct != "f16" {
-			slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
-		}
+		slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
 	}
 
 	gpuLibs := ml.LibraryPaths(gpus)
```
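The user-override detection above works by probing the env-config getter with both possible defaults: if the two results differ, the variable is unset and the mode stays auto; if they agree, the user's explicit value picks enabled or disabled. A standalone sketch of that trick, using a hypothetical `boolWithDefault` helper rather than the real `envconfig` API:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// boolWithDefault is a stand-in for a config getter like the one used above:
// it returns def unless the environment variable is set to a parseable bool.
// (Hypothetical helper for illustration only.)
func boolWithDefault(key string, def bool) bool {
	if v, err := strconv.ParseBool(os.Getenv(key)); err == nil {
		return v
	}
	return def
}

func main() {
	const key = "OLLAMA_FLASH_ATTENTION"

	// Unset variable: the two calls echo their different defaults and disagree,
	// so the effective mode is "auto". Set variable: both calls agree, and the
	// user's explicit value decides enabled vs. disabled.
	if boolWithDefault(key, true) != boolWithDefault(key, false) {
		fmt.Println("flash attention: auto")
	} else if boolWithDefault(key, false) {
		fmt.Println("flash attention: enabled")
	} else {
		fmt.Println("flash attention: disabled")
	}
}
```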
```diff
@@ -487,6 +440,7 @@ type LoadRequest struct {
 
 type LoadResponse struct {
 	Success bool
+	Request LoadRequest // The original request with fields updated that the runner had to modify
 	Memory  ml.BackendMemory
 }
 
```
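The new `Request` field lets the runner echo back a request it had to adjust so the client can adopt those adjustments instead of re-deriving them. A simplified sketch of that round-trip, with illustrative stand-in types rather than the real `LoadRequest` field set:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Reduced stand-ins for the request/response pair above (illustrative fields only).
type loadRequest struct {
	FlashAttention string
	KvCacheType    string
	BatchSize      int
}

type loadResponse struct {
	Success bool
	Request loadRequest // echoed back with any adjustments the runner made
}

func main() {
	req := loadRequest{FlashAttention: "auto", KvCacheType: "q8_0", BatchSize: 512}

	// The runner may downgrade settings it cannot honor and report that back.
	adjusted := req
	adjusted.FlashAttention = "disabled"
	adjusted.KvCacheType = ""

	out, _ := json.Marshal(loadResponse{Success: true, Request: adjusted})

	var resp loadResponse
	_ = json.Unmarshal(out, &resp)

	// The client adopts the echoed request so later load operations reuse the
	// adjusted values.
	req = resp.Request
	fmt.Printf("client now uses: %+v\n", req)
}
```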
```diff
@@ -511,6 +465,11 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 		s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
 	}
 
+	if s.loadRequest.FlashAttention != ml.FlashAttentionEnabled && ggml.KVCacheTypeIsQuantized(s.loadRequest.KvCacheType) {
+		slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", s.loadRequest.KvCacheType)
+		s.loadRequest.KvCacheType = ""
+	}
+
 	// Check if embedding model and adjust batch size accordingly
 	_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
 	if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
```
```diff
@@ -769,6 +728,7 @@ nextOperation:
 
 		resp.Memory.Log(slog.LevelDebug)
 		slog.Debug("memory", "success", resp.Success, "required", resp.Memory)
+		s.loadRequest = resp.Request // Incorporate any adjustments from the runner to avoid needing to do them again
 
 		pastAllocations[gpuLayers.Hash()] = struct{}{}
 		s.mem = &resp.Memory
```
```diff
@@ -822,6 +782,7 @@ nextOperation:
 
 		resp.Memory.Log(slog.LevelDebug)
 		slog.Debug("memory", "success", resp.Success, "required", resp.Memory)
+		s.loadRequest = resp.Request
 
 		if resp.Success {
 			verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff)
```
```diff
@@ -118,7 +118,7 @@ type Context interface {
 	// graph, simply preallocates memory. Typically called with a
 	// worst case graph to ensure all resources are available for
 	// for future inference.
-	Reserve()
+	Reserve() error
 
 	MaxGraphNodes() int
 	Close()
```
```diff
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	if b.flashAttention == ml.FlashAttentionEnabled {
+	if b.flashAttention != ml.FlashAttentionDisabled {
 		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
 	} else {
 		return ml.CacheConfig{CachePadding: 256, PermutedV: true}
```
```diff
@@ -842,11 +842,16 @@ func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
 	}
 }
 
-func (c *Context) Reserve() {
+func (c *Context) Reserve() error {
 	if c.batchSize > 0 {
 		C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
 	}
 
+	flashBackendAssignments, err := validateGraph(c.graph, c.b.flashAttention)
+	if err != nil {
+		return err
+	}
+
 	reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
 
 	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
```
```diff
@@ -867,6 +872,142 @@ func (c *Context) Reserve() {
 	if !reserved {
 		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
 	}
+
+	// If flash attention is in auto mode, ensure that the scheduler placed the flash attention on the same
+	// (or higher priority) backend as we originally loaded the weights. If it's a lower priority backend (i.e. CPU),
+	// that means that backend likely does not support flash attention for this graph.
+	if c.b.flashAttention == ml.FlashAttentionAuto {
+		flashIdx := 0
+		for i := range C.ggml_graph_n_nodes(c.graph) {
+			node := C.ggml_graph_node(c.graph, i)
+			if node.op != C.GGML_OP_FLASH_ATTN_EXT {
+				continue
+			}
+
+			if flashIdx >= len(flashBackendAssignments) {
+				slog.Debug("flash attention assignment missing",
+					"index", flashIdx,
+					"tensor", C.GoString(C.ggml_get_name(node)))
+				return errors.New("flash attention not supported by backend")
+			}
+
+			assignedBT := flashBackendAssignments[flashIdx]
+			flashIdx++
+
+			if node.buffer == nil || assignedBT == nil {
+				continue
+			}
+
+			bufferType := C.ggml_backend_buffer_get_type(node.buffer)
+
+			actualPriority := bufferTypePriority(bufferType, c.b.schedBufts)
+			expectedPriority := bufferTypePriority(assignedBT, c.b.schedBufts)
+
+			// A lower numbered priority is better here
+			if actualPriority > expectedPriority {
+				slog.Debug("flash attention not supported by backend",
+					"tensor", C.GoString(C.ggml_get_name(node)),
+					"assigned_buffer_type", C.GoString(C.ggml_backend_buft_name(bufferType)),
+					"assigned_priority", actualPriority,
+					"expected_buffer_type", C.GoString(C.ggml_backend_buft_name(assignedBT)),
+					"expected_priority", expectedPriority)
+				return errors.New("flash attention not supported by backend")
+			}
+		}
+	}
+
+	return nil
 }
+
+func bufferTypePriority(buft C.ggml_backend_buffer_type_t, schedBufts []C.ggml_backend_buffer_type_t) int {
+	for i, b := range schedBufts {
+		if b == buft {
+			return i
+		}
+	}
+
+	return len(schedBufts)
+}
+
+// Check that there are no illegal operations and build a mapping of flash attention operation locations
+// from before the scheduler runs to compare to the result afterwards.
+func validateGraph(graph *C.struct_ggml_cgraph, flashAttention ml.FlashAttentionType) ([]C.ggml_backend_buffer_type_t, error) {
+	var assignments []C.ggml_backend_buffer_type_t
+
+	for i := range C.ggml_graph_n_nodes(graph) {
+		node := C.ggml_graph_node(graph, i)
+
+		switch node.op {
+		// Only flash attention supports quantized KV cache, so if we have a matmul that uses a quantized input (other than weights),
+		// it means that the model is using its own implementation of attention.
+		case C.GGML_OP_MUL_MAT:
+			for srcIndex := range int(C.GGML_MAX_SRC) {
+				src := node.src[srcIndex]
+				if src == nil {
+					continue
+				}
+
+				var quantized *C.struct_ggml_tensor
+				for current := src; current != nil; current = current.view_src {
+					if C.ggml_is_quantized(current._type) {
+						quantized = current
+						break
+					}
+				}
+
+				// If matmul has a quantized input, it is only supported if it is weights (due to uniform stride)
+				if quantized != nil &&
+					!(quantized.buffer != nil && C.ggml_backend_buffer_get_usage(quantized.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+					slog.Debug("unsupported quantized matmul input",
+						"tensor", C.GoString(C.ggml_get_name(node)),
+						"src", C.GoString(C.ggml_get_name(src)),
+						"type", C.GoString(C.ggml_type_name(quantized._type)))
+					return nil, errors.New("unsupported quantized matmul input")
+				}
+			}
+
+		// Build a mapping of flash attention operations to their most direct weight input. We do this before the scheduler runs
+		// because the graph is fully connected. After scheduling, the graph is hard to trace because it is broken up into splits.
+		// We index by flash attention number (more or less equivalent to layer) since that is persistent across scheduling.
+		case C.GGML_OP_FLASH_ATTN_EXT:
+			if flashAttention == ml.FlashAttentionAuto {
+				// Breadth-first search for the first ancestor that has a buffer with weights
+				queue := []*C.struct_ggml_tensor{node}
+				visited := make(map[*C.struct_ggml_tensor]struct{})
+
+				var ancestor *C.struct_ggml_tensor
+				for len(queue) > 0 {
+					current := queue[0]
+					queue = queue[1:]
+
+					if _, ok := visited[current]; ok {
+						continue
+					}
+					visited[current] = struct{}{}
+
+					// Only use weights as reference points - we don't want to use inputs like the cache mask, which are always on the CPU
+					if current.buffer != nil && C.ggml_backend_buffer_get_usage(current.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS {
+						ancestor = current
+						break
+					}
+
+					for srcIndex := range int(C.GGML_MAX_SRC) {
+						if src := current.src[srcIndex]; src != nil {
+							queue = append(queue, src)
+						}
+					}
+				}
+
+				var bufferType C.ggml_backend_buffer_type_t
+				if ancestor != nil {
+					bufferType = C.ggml_backend_buffer_get_type(ancestor.buffer)
+				}
+				assignments = append(assignments, bufferType)
+			}
+		}
+	}
+
+	return assignments, nil
+}
 
 func (c *Context) MaxGraphNodes() int {
```
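The `bufferTypePriority` comparison reduces to "where does this buffer type sit in the scheduler's ordered backend list": a larger index means a lower-priority backend, which is taken as the scheduler having fallen back from the expected device. A standalone sketch of the same check with plain strings standing in for GGML buffer types (names are illustrative):

```go
package main

import "fmt"

// priority returns the index of buft in the scheduler's ordered buffer-type
// list; unknown types rank after everything else (lower index = higher priority).
func priority(buft string, schedBufts []string) int {
	for i, b := range schedBufts {
		if b == buft {
			return i
		}
	}
	return len(schedBufts)
}

func main() {
	// Illustrative ordering: GPU buffer types come before the CPU fallback.
	sched := []string{"CUDA0", "CUDA_Host", "CPU"}

	expected := "CUDA0" // where the attention weights live
	actual := "CPU"     // where the scheduler actually placed the flash attention op

	// A larger index than expected means the op fell back to a lower-priority
	// backend, which is treated as "flash attention not supported here".
	if priority(actual, sched) > priority(expected, sched) {
		fmt.Println("flash attention fell back to a lower-priority backend; disable and retry")
	}
}
```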
```diff
@@ -1679,7 +1820,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	if t.b.flashAttention == ml.FlashAttentionEnabled {
+	if t.b.flashAttention != ml.FlashAttentionDisabled {
 		value = value.Permute(ctx, 0, 2, 1, 3)
 
 		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
```
```diff
@@ -935,13 +935,13 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 
 	case llm.LoadOperationClose:
 		// No-op for us
-		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
+		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil {
 			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		}
 		return
 	}
 
-	resp := llm.LoadResponse{Success: true}
+	resp := llm.LoadResponse{Success: true, Request: req}
 	if err := json.NewEncoder(w).Encode(&resp); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		return
```
```diff
@@ -98,7 +98,10 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 			}
 		}
 	} else {
-		computeCtx.Reserve()
+		err := computeCtx.Reserve()
+		if err != nil {
+			return nil, err
+		}
 	}
 }
 
```
```diff
@@ -1160,22 +1160,12 @@ func (s *Server) reserveWorstCaseGraph(prompt bool) error {
 	}
 
 	ctx.SetBatchSize(batchSize)
-	ctx.Forward(t).Reserve()
-
-	return nil
+	return ctx.Forward(t).Reserve()
 }
 
 // allocModel pre-allocates the maximum needed memory for a model
 // based on the given parameters
-func (s *Server) allocModel(
-	mpath string,
-	params ml.BackendParams,
-	loraPath []string,
-	parallel int,
-	kvCacheType string,
-	kvSize int,
-	multiUserCache bool,
-) (panicErr error) {
+func (s *Server) allocModel(mpath string, req *llm.LoadRequest) (panicErr error) {
 	// Convert memory allocation panics to errors
 	defer func() {
 		if r := recover(); r != nil {
```
```diff
@@ -1192,43 +1182,73 @@ func (s *Server) allocModel(
 		}
 	}()
 
-	var err error
-	s.model, err = model.New(mpath, params)
-	if err != nil {
-		return err
-	}
-
-	// TODO(jessegross): LoRA loading
-	if len(loraPath) > 0 {
-		return errors.New("loras are not yet implemented")
-	}
-
-	if s.model.Config().Cache == nil {
-		if parallel > 1 {
-			parallel = 1
-			slog.Warn("model does not support caching, disabling parallel processing")
-		}
-		if s.batchSize < kvSize {
-			s.batchSize = kvSize
-			slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
-		}
-	}
-
-	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
-	if err != nil {
-		return err
-	}
-
-	s.parallel = parallel
-	s.seqs = make([]*Sequence, s.parallel)
-	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
-
-	err = s.reserveWorstCaseGraph(true)
-	if err != nil {
-		return nil
-	}
-
-	return s.reserveWorstCaseGraph(false)
+reload:
+	for range 2 {
+		params := ml.BackendParams{
+			AllocMemory:    req.Operation != llm.LoadOperationFit,
+			NumThreads:     req.NumThreads,
+			GPULayers:      req.GPULayers,
+			FlashAttention: req.FlashAttention,
+		}
+
+		var err error
+		s.model, err = model.New(mpath, params)
+		if err != nil {
+			return err
+		}
+
+		// TODO(jessegross): LoRA loading
+		if len(req.LoraPath) > 0 {
+			return errors.New("loras are not yet implemented")
+		}
+
+		if params.FlashAttention == ml.FlashAttentionDisabled && ggml.KVCacheTypeIsQuantized(req.KvCacheType) {
+			slog.Warn("quantized kv cache requested but flash attention disabled", "type", req.KvCacheType)
+			req.KvCacheType = ""
+		}
+
+		if s.model.Config().Cache == nil {
+			if req.Parallel > 1 {
+				req.Parallel = 1
+				slog.Warn("model does not support caching, disabling parallel processing")
+			}
+			if req.BatchSize < req.KvSize {
+				req.BatchSize = req.KvSize
+				slog.Warn("model does not support caching, setting batch size to context length", "batch_size", req.KvSize)
+			}
+		}
+
+		s.cache, err = NewInputCache(s.model, req.KvCacheType, int32(req.KvSize), req.Parallel, req.BatchSize, req.MultiUserCache)
+		if err != nil {
+			return err
+		}
+
+		s.batchSize = req.BatchSize
+		s.parallel = req.Parallel
+		s.seqs = make([]*Sequence, s.parallel)
+		s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
+
+		for _, prompt := range []bool{true, false} {
+			if err := s.reserveWorstCaseGraph(prompt); err != nil {
+				if req.FlashAttention != ml.FlashAttentionDisabled {
+					slog.Warn("flash attention enabled but not supported by model")
+					req.FlashAttention = ml.FlashAttentionDisabled
+					s.closeModel()
+					continue reload
+				}
+
+				return err
+			}
+		}
+
+		if req.FlashAttention == ml.FlashAttentionAuto {
+			req.FlashAttention = ml.FlashAttentionEnabled
+		}
+
+		return nil
+	}
+
+	return errors.New("unable to allocate model")
 }
 
 // closeModel frees all memory associated with a model
```
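The `reload` loop above is a bounded try-then-degrade pattern: attempt allocation with the requested flash-attention mode and, if graph reservation rejects it, disable flash attention and retry exactly once. A stripped-down sketch of that control flow (helpers and error values are illustrative, not the real runner API):

```go
package main

import (
	"errors"
	"fmt"
)

var errFlashAttnUnsupported = errors.New("flash attention not supported by backend")

// tryAlloc stands in for building the model and reserving worst-case graphs.
func tryAlloc(flashAttention bool) error {
	if flashAttention {
		return errFlashAttnUnsupported // pretend the backend rejected it
	}
	return nil
}

func allocModel(flashAttention bool) error {
	for range 2 { // at most one downgrade-and-retry
		if err := tryAlloc(flashAttention); err != nil {
			if flashAttention {
				fmt.Println("flash attention enabled but not supported; retrying without it")
				flashAttention = false
				continue
			}
			return err
		}
		return nil
	}
	return errors.New("unable to allocate model")
}

func main() {
	fmt.Println(allocModel(true)) // succeeds (<nil>) after one retry
}
```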
```diff
@@ -1243,7 +1263,15 @@ func (s *Server) closeModel() {
 
 // loadModel loads the weights for a model. The memory must already
 // have been allocated with allocModel
-func (s *Server) loadModel() {
+func (s *Server) loadModel(req llm.LoadRequest) {
+	if req.FlashAttention != ml.FlashAttentionDisabled {
+		slog.Info("enabling flash attention")
+	}
+
+	if ggml.KVCacheTypeIsQuantized(req.KvCacheType) {
+		slog.Info("enabling kv cache quantization", "type", req.KvCacheType)
+	}
+
 	err := s.model.Backend().Load(context.TODO(),
 		func(progress float32) {
 			s.progress = progress
```
```diff
@@ -1279,7 +1307,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 
 	if req.Operation == llm.LoadOperationClose {
 		s.closeModel()
-		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
+		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil {
 			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		}
 		return
```
```diff
@@ -1288,27 +1316,16 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 	s.lastLoad.Operation = req.Operation
 	loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)
 
-	s.lastLoad = req
-
 	if loadModel {
 		s.closeModel()
 
-		params := ml.BackendParams{
-			AllocMemory:    req.Operation != llm.LoadOperationFit,
-			NumThreads:     req.NumThreads,
-			GPULayers:      req.GPULayers,
-			FlashAttention: req.FlashAttention,
-		}
-
-		s.batchSize = req.BatchSize
-
-		err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
+		err := s.allocModel(s.modelPath, &req)
 		if err != nil {
 			s.closeModel()
 
 			var noMem ml.ErrNoMem
 			if errors.As(err, &noMem) {
-				resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
+				resp := llm.LoadResponse{Success: false, Request: req, Memory: noMem.BackendMemory}
 				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 				}
```
```diff
@@ -1321,6 +1338,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
+	s.lastLoad = req
 	mem := s.model.Backend().BackendMemory()
 
 	switch req.Operation {
```
```diff
@@ -1332,10 +1350,10 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 
 	case llm.LoadOperationCommit:
 		s.status = llm.ServerStatusLoadingModel
-		go s.loadModel()
+		go s.loadModel(req)
 	}
 
-	resp := llm.LoadResponse{Success: true, Memory: mem}
+	resp := llm.LoadResponse{Success: true, Request: req, Memory: mem}
 	if err := json.NewEncoder(w).Encode(&resp); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		return
```