Mirror of https://github.com/ollama/ollama.git
Compare view: branch parth/decr, 1 commit (1200e427f7 by jessegross)

This commit reworks flash attention handling around a tri-state setting (disabled, enabled, auto) that is resolved at load time: the per-model allowlist in GGML is dropped, backend support is validated while reserving the compute graph, quantized KV cache types are cleared when flash attention ends up disabled, and the runner echoes any adjustments it made back to the client in LoadResponse.Request.

@@ -813,43 +813,13 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
 }
 
 // KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
-func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
+func KVCacheTypeIsQuantized(cacheType string) bool {
 	if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
 		return false
 	}
 	return true
 }
 
-// SupportsFlashAttention checks if the model supports flash attention
-func (f GGML) SupportsFlashAttention() bool {
-	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
-	if isEmbedding {
-		return false
-	}
-
-	if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
-		return false
-	}
-
-	// Check head counts match and are non-zero
-	headCountK := f.KV().EmbeddingHeadCountK()
-	headCountV := f.KV().EmbeddingHeadCountV()
-	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
-}
-
-// FlashAttention checks if the model should enable flash attention
-func (f GGML) FlashAttention() bool {
-	return slices.Contains([]string{
-		"bert",
-		"gemma3",
-		"gptoss", "gpt-oss",
-		"mistral3",
-		"olmo3",
-		"qwen3", "qwen3moe",
-		"qwen3vl", "qwen3vlmoe",
-	}, f.KV().String("general.architecture"))
-}
-
 // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
 func kvCacheBytesPerElement(cacheType string) float64 {
 	switch cacheType {
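
The predicate this hunk keeps is self-contained, so it can be exercised on its own. A minimal standalone sketch (the lowercase helper and the sample cache-type strings are illustrative, not part of the change):

package main

import "fmt"

// kvCacheTypeIsQuantized mirrors the predicate kept by this hunk: anything
// other than the unquantized float formats counts as a quantized cache type.
func kvCacheTypeIsQuantized(cacheType string) bool {
	if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
		return false
	}
	return true
}

func main() {
	for _, t := range []string{"", "f16", "bf16", "q8_0", "q4_0"} {
		fmt.Printf("%q quantized=%v\n", t, kvCacheTypeIsQuantized(t))
	}
}
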
@@ -696,7 +696,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
 
 func (c *testContext) Compute(...ml.Tensor) {}
 
-func (c *testContext) Reserve() {}
+func (c *testContext) Reserve() error { return nil }
 
 func (c *testContext) MaxGraphNodes() int {
 	return 10
@@ -188,73 +188,26 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	if len(projectors) > 0 && llamaModel != nil {
 		loadRequest.ProjectorPath = projectors[0]
 	}
-	// Determine if the user has forced FA on or off
-	faUserSet := false
-	if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
-		faUserSet = true
-	}
 
-	fa := envconfig.FlashAttention(f.FlashAttention())
+	// Determine if the user has forced FA on or off
+	if envconfig.FlashAttention(true) != envconfig.FlashAttention(false) {
+		loadRequest.FlashAttention = ml.FlashAttentionAuto
+	} else if envconfig.FlashAttention(false) {
+		loadRequest.FlashAttention = ml.FlashAttentionEnabled
+	}
 
 	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
-	// that can handle it.
-	if fa && !ml.FlashAttentionSupported(gpus) {
+	// that can handle it. There are still holes in GGML's hardware detection for flash attention.
+	if loadRequest.FlashAttention != ml.FlashAttentionDisabled && !ml.FlashAttentionSupported(gpus) {
 		slog.Warn("flash attention enabled but not supported by gpu")
-		fa = false
-	}
-
-	if fa && !f.SupportsFlashAttention() {
-		slog.Warn("flash attention enabled but not supported by model")
-		fa = false
+		loadRequest.FlashAttention = ml.FlashAttentionDisabled
 	}
 
 	kvct := strings.ToLower(envconfig.KvCacheType())
-	if textProcessor == nil {
-		flashAttention := ml.FlashAttentionAuto
-		if faUserSet {
-			if fa {
-				flashAttention = ml.FlashAttentionEnabled
-			} else {
-				flashAttention = ml.FlashAttentionDisabled
-			}
-		}
-
-		if kvct != "" {
-			if f.KVCacheTypeIsQuantized(kvct) {
-				if flashAttention != ml.FlashAttentionEnabled {
-					slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
-					loadRequest.KvCacheType = ""
-				} else if f.SupportsKVCacheType(kvct) {
-					loadRequest.KvCacheType = kvct
-				} else {
-					slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
-				}
-			} else {
-				if f.SupportsKVCacheType(kvct) {
-					loadRequest.KvCacheType = kvct
-				} else {
-					slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
-				}
-			}
-		}
-		loadRequest.FlashAttention = flashAttention
+	if f.SupportsKVCacheType(kvct) {
+		loadRequest.KvCacheType = kvct
 	} else {
-		// For Ollama engine, use our SupportsFlashAttention logic
-		if fa {
-			slog.Info("enabling flash attention")
-			loadRequest.FlashAttention = ml.FlashAttentionEnabled
-
-			// Flash Attention also supports kv cache quantization
-			// Enable if the requested and kv cache type is supported by the model
-			if f.SupportsKVCacheType(kvct) {
-				loadRequest.KvCacheType = kvct
-			} else {
-				slog.Warn("kv cache type not supported by model", "type", kvct)
-			}
-		} else if kvct != "" && kvct != "f16" {
-			slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
-		}
+		slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
 	}
 
 	gpuLibs := ml.LibraryPaths(gpus)
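
The unset-detection above relies on envconfig.FlashAttention taking a default and only falling back to it when OLLAMA_FLASH_ATTENTION is not set: probing with both defaults and comparing the results distinguishes "user forced a value" from "unset, so pick auto". A hedged standalone sketch of that pattern (boolFromEnv and the variable name are stand-ins, not the real envconfig API):

package main

import (
	"fmt"
	"os"
	"strconv"
)

// boolFromEnv is a stand-in for a defaulted env lookup like envconfig.FlashAttention:
// it returns def unless the variable is set to something parseable as a bool.
func boolFromEnv(key string, def bool) bool {
	if v, ok := os.LookupEnv(key); ok {
		if b, err := strconv.ParseBool(v); err == nil {
			return b
		}
	}
	return def
}

func main() {
	const key = "EXAMPLE_FLASH_ATTENTION"

	// Probe with both defaults: if the results differ, the variable is unset
	// (each call fell back to its default), so the caller should pick "auto".
	if boolFromEnv(key, true) != boolFromEnv(key, false) {
		fmt.Println("flash attention: auto")
	} else if boolFromEnv(key, false) {
		fmt.Println("flash attention: enabled")
	} else {
		fmt.Println("flash attention: disabled")
	}
}
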
@@ -487,6 +440,7 @@ type LoadRequest struct {
 
 type LoadResponse struct {
 	Success bool
+	Request LoadRequest // The original request with fields updated that the runner had to modify
 	Memory ml.BackendMemory
 }
 
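
Carrying the (possibly adjusted) request back in LoadResponse lets the client adopt the runner's view of the settings instead of re-deriving it on the next operation. A rough sketch of that round trip with simplified stand-in types (not the real llm structs):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the request/response pair; field names here are
// illustrative, not the real llm.LoadRequest/LoadResponse definitions.
type loadRequest struct {
	KvCacheType    string
	FlashAttention string
}

type loadResponse struct {
	Success bool
	Request loadRequest // the request with any fields the runner had to modify
}

func main() {
	req := loadRequest{KvCacheType: "q8_0", FlashAttention: "auto"}

	// Runner side: resolve "auto" and echo the adjusted request back.
	adjusted := req
	adjusted.FlashAttention = "enabled"
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(&loadResponse{Success: true, Request: adjusted}); err != nil {
		panic(err)
	}

	// Client side: adopt the runner's adjustments so later operations reuse them.
	var resp loadResponse
	if err := json.NewDecoder(&buf).Decode(&resp); err != nil {
		panic(err)
	}
	req = resp.Request
	fmt.Printf("client now uses: %+v\n", req)
}
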
@@ -511,6 +465,11 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 		s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
 	}
 
+	if s.loadRequest.FlashAttention != ml.FlashAttentionEnabled && ggml.KVCacheTypeIsQuantized(s.loadRequest.KvCacheType) {
+		slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", s.loadRequest.KvCacheType)
+		s.loadRequest.KvCacheType = ""
+	}
+
 	// Check if embedding model and adjust batch size accordingly
 	_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
 	if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
@@ -769,6 +728,7 @@ nextOperation:
 
 			resp.Memory.Log(slog.LevelDebug)
 			slog.Debug("memory", "success", resp.Success, "required", resp.Memory)
+			s.loadRequest = resp.Request // Incorporate any adjustments from the runner to avoid needing to do them again
 
 			pastAllocations[gpuLayers.Hash()] = struct{}{}
 			s.mem = &resp.Memory
@@ -822,6 +782,7 @@ nextOperation:
 
 		resp.Memory.Log(slog.LevelDebug)
 		slog.Debug("memory", "success", resp.Success, "required", resp.Memory)
+		s.loadRequest = resp.Request
 
 		if resp.Success {
 			verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff)
@@ -118,7 +118,7 @@ type Context interface {
 	// graph, simply preallocates memory. Typically called with a
 	// worst case graph to ensure all resources are available for
 	// for future inference.
-	Reserve()
+	Reserve() error
 
 	MaxGraphNodes() int
 	Close()
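
Because ml.Context is an interface, the Reserve() to Reserve() error change above ripples through every implementation and call site. A reduced sketch of that shape (the interface and test type below are stand-ins, not the real ml package):

package main

import (
	"errors"
	"fmt"
)

// context is a reduced stand-in for the ml.Context interface after this change:
// Reserve now reports failures instead of returning nothing.
type context interface {
	Reserve() error
}

type testContext struct{ fail bool }

func (c *testContext) Reserve() error {
	if c.fail {
		return errors.New("flash attention not supported by backend")
	}
	return nil
}

func main() {
	var ctx context = &testContext{fail: true}
	if err := ctx.Reserve(); err != nil {
		// Callers can now react, e.g. retry allocation without flash attention.
		fmt.Println("reserve failed:", err)
	}
}
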
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	if b.flashAttention == ml.FlashAttentionEnabled {
+	if b.flashAttention != ml.FlashAttentionDisabled {
 		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
 	} else {
 		return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -842,11 +842,16 @@ func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
 	}
 }
 
-func (c *Context) Reserve() {
+func (c *Context) Reserve() error {
 	if c.batchSize > 0 {
 		C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
 	}
 
+	flashBackendAssignments, err := validateGraph(c.graph, c.b.flashAttention)
+	if err != nil {
+		return err
+	}
+
 	reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
 
 	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
@@ -867,6 +872,142 @@ func (c *Context) Reserve() {
 	if !reserved {
 		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
 	}
 
+	// If flash attention is in auto mode, ensure that the scheduler placed the flash attention on the same
+	// (or higher priority) backend as we originally loaded the weights. If it's a lower priority backend (i.e. CPU),
+	// that means that backend likely does not support flash attention for this graph.
+	if c.b.flashAttention == ml.FlashAttentionAuto {
+		flashIdx := 0
+		for i := range C.ggml_graph_n_nodes(c.graph) {
+			node := C.ggml_graph_node(c.graph, i)
+			if node.op != C.GGML_OP_FLASH_ATTN_EXT {
+				continue
+			}
+
+			if flashIdx >= len(flashBackendAssignments) {
+				slog.Debug("flash attention assignment missing",
+					"index", flashIdx,
+					"tensor", C.GoString(C.ggml_get_name(node)))
+				return errors.New("flash attention not supported by backend")
+			}
+
+			assignedBT := flashBackendAssignments[flashIdx]
+			flashIdx++
+
+			if node.buffer == nil || assignedBT == nil {
+				continue
+			}
+
+			bufferType := C.ggml_backend_buffer_get_type(node.buffer)
+
+			actualPriority := bufferTypePriority(bufferType, c.b.schedBufts)
+			expectedPriority := bufferTypePriority(assignedBT, c.b.schedBufts)
+
+			// A lower numbered priority is better here
+			if actualPriority > expectedPriority {
+				slog.Debug("flash attention not supported by backend",
+					"tensor", C.GoString(C.ggml_get_name(node)),
+					"assigned_buffer_type", C.GoString(C.ggml_backend_buft_name(bufferType)),
+					"assigned_priority", actualPriority,
+					"expected_buffer_type", C.GoString(C.ggml_backend_buft_name(assignedBT)),
+					"expected_priority", expectedPriority)
+				return errors.New("flash attention not supported by backend")
+			}
+		}
+	}
+
+	return nil
+}
+
+func bufferTypePriority(buft C.ggml_backend_buffer_type_t, schedBufts []C.ggml_backend_buffer_type_t) int {
+	for i, b := range schedBufts {
+		if b == buft {
+			return i
+		}
+	}
+
+	return len(schedBufts)
+}
+
+// Check that there are no illegal operations and build a mapping of flash attention operation locations
+// from before the scheduler runs to compare to the result afterwards.
+func validateGraph(graph *C.struct_ggml_cgraph, flashAttention ml.FlashAttentionType) ([]C.ggml_backend_buffer_type_t, error) {
+	var assignments []C.ggml_backend_buffer_type_t
+
+	for i := range C.ggml_graph_n_nodes(graph) {
+		node := C.ggml_graph_node(graph, i)
+
+		switch node.op {
+		// Only flash attention supports quantized KV cache, so if we have a matmul that uses a quantized input (other than weights),
+		// it means that the model is using its own implementation of attention.
+		case C.GGML_OP_MUL_MAT:
+			for srcIndex := range int(C.GGML_MAX_SRC) {
+				src := node.src[srcIndex]
+				if src == nil {
+					continue
+				}
+
+				var quantized *C.struct_ggml_tensor
+				for current := src; current != nil; current = current.view_src {
+					if C.ggml_is_quantized(current._type) {
+						quantized = current
+						break
+					}
+				}
+
+				// If matmul has a quantized input, it is only supported if it is weights (due to uniform stride)
+				if quantized != nil &&
+					!(quantized.buffer != nil && C.ggml_backend_buffer_get_usage(quantized.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+					slog.Debug("unsupported quantized matmul input",
+						"tensor", C.GoString(C.ggml_get_name(node)),
+						"src", C.GoString(C.ggml_get_name(src)),
+						"type", C.GoString(C.ggml_type_name(quantized._type)))
+					return nil, errors.New("unsupported quantized matmul input")
+				}
+			}
+
+		// Build a mapping of flash attention operations to their most direct weight input. We do this before the scheduler runs
+		// because the graph is fully connected. After scheduling, the graph is hard to trace because it is broken up into splits.
+		// We index by flash attention number (more or less equivalent to layer) since that is persistent across scheduling.
+		case C.GGML_OP_FLASH_ATTN_EXT:
+			if flashAttention == ml.FlashAttentionAuto {
+				// Breadth-first search for the first ancestor that has a buffer with weights
+				queue := []*C.struct_ggml_tensor{node}
+				visited := make(map[*C.struct_ggml_tensor]struct{})
+
+				var ancestor *C.struct_ggml_tensor
+				for len(queue) > 0 {
+					current := queue[0]
+					queue = queue[1:]
+
+					if _, ok := visited[current]; ok {
+						continue
+					}
+					visited[current] = struct{}{}
+
+					// Only use weights as reference points - we don't want to use inputs like the cache mask, which are always on the CPU
+					if current.buffer != nil && C.ggml_backend_buffer_get_usage(current.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS {
+						ancestor = current
+						break
+					}
+
+					for srcIndex := range int(C.GGML_MAX_SRC) {
+						if src := current.src[srcIndex]; src != nil {
+							queue = append(queue, src)
+						}
+					}
+				}
+
+				var bufferType C.ggml_backend_buffer_type_t
+				if ancestor != nil {
+					bufferType = C.ggml_backend_buffer_get_type(ancestor.buffer)
+				}
+				assignments = append(assignments, bufferType)
+			}
+		}
+	}
+
+	return assignments, nil
 }
 
 func (c *Context) MaxGraphNodes() int {
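
Two small ideas carry the auto-detection above: a backend's priority is just its index in the scheduler's ordered buffer-type list (unknown types sort last), and each flash-attention node is compared against the nearest weight ancestor found by breadth-first search before scheduling. A cgo-free sketch of both on a toy graph (the node struct and backend names are illustrative, not GGML's types):

package main

import "fmt"

// node is a toy stand-in for a graph tensor: a buffer-type name plus its inputs.
type node struct {
	name    string
	bufType string // "" means no buffer assigned
	weights bool   // true if the buffer holds model weights
	srcs    []*node
}

// priority is the index of a buffer type in the scheduler's ordered list;
// anything unknown sorts after everything else (lowest priority).
func priority(bufType string, schedBufts []string) int {
	for i, b := range schedBufts {
		if b == bufType {
			return i
		}
	}
	return len(schedBufts)
}

// weightAncestor walks the inputs breadth-first and returns the first
// ancestor whose buffer holds weights, mirroring the pre-scheduling pass.
func weightAncestor(n *node) *node {
	queue := []*node{n}
	visited := map[*node]struct{}{}
	for len(queue) > 0 {
		cur := queue[0]
		queue = queue[1:]
		if _, ok := visited[cur]; ok {
			continue
		}
		visited[cur] = struct{}{}
		if cur.weights {
			return cur
		}
		queue = append(queue, cur.srcs...)
	}
	return nil
}

func main() {
	sched := []string{"CUDA0", "CPU"} // lower index = higher priority

	wk := &node{name: "attn_k.weight", bufType: "CUDA0", weights: true}
	mask := &node{name: "kq_mask", bufType: "CPU"}
	fa := &node{name: "flash_attn_0", bufType: "CPU", srcs: []*node{mask, wk}}

	anc := weightAncestor(fa)
	fmt.Println("expected backend:", anc.bufType, "actual backend:", fa.bufType)
	if priority(fa.bufType, sched) > priority(anc.bufType, sched) {
		fmt.Println("flash attention fell back to a lower-priority backend")
	}
}
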
@@ -1679,7 +1820,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	if t.b.flashAttention == ml.FlashAttentionEnabled {
+	if t.b.flashAttention != ml.FlashAttentionDisabled {
 		value = value.Permute(ctx, 0, 2, 1, 3)
 
 		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
@@ -935,13 +935,13 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 
 	case llm.LoadOperationClose:
 		// No-op for us
-		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
+		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil {
 			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		}
 		return
 	}
 
-	resp := llm.LoadResponse{Success: true}
+	resp := llm.LoadResponse{Success: true, Request: req}
 	if err := json.NewEncoder(w).Encode(&resp); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		return
@@ -98,7 +98,10 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 				}
 			}
 		} else {
-			computeCtx.Reserve()
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
+			}
 		}
 	}
 
@@ -1160,22 +1160,12 @@ func (s *Server) reserveWorstCaseGraph(prompt bool) error {
 	}
 
 	ctx.SetBatchSize(batchSize)
-	ctx.Forward(t).Reserve()
-
-	return nil
+	return ctx.Forward(t).Reserve()
 }
 
 // allocModel pre-allocates the maximum needed memory for a model
 // based on the given parameters
-func (s *Server) allocModel(
-	mpath string,
-	params ml.BackendParams,
-	loraPath []string,
-	parallel int,
-	kvCacheType string,
-	kvSize int,
-	multiUserCache bool,
-) (panicErr error) {
+func (s *Server) allocModel(mpath string, req *llm.LoadRequest) (panicErr error) {
 	// Convert memory allocation panics to errors
 	defer func() {
 		if r := recover(); r != nil {
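
Collapsing the parameter list into a single *llm.LoadRequest also lets allocModel write its adjustments (cleared cache type, downgraded flash attention, clamped batch size) back into the request that the handler later echoes to the client. A tiny sketch of why the pointer matters, with a simplified stand-in struct:

package main

import "fmt"

// loadRequest is a simplified stand-in; field names are illustrative.
type loadRequest struct {
	KvCacheType    string
	FlashAttention bool
}

// alloc mutates the caller's request through the pointer, so any adjustment
// it makes is visible when the handler encodes the response.
func alloc(req *loadRequest) {
	if req.KvCacheType != "" && !req.FlashAttention {
		req.KvCacheType = "" // quantized cache requires flash attention
	}
}

func main() {
	req := loadRequest{KvCacheType: "q8_0", FlashAttention: false}
	alloc(&req)
	fmt.Printf("request after allocation: %+v\n", req)
}
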
@@ -1192,43 +1182,73 @@ func (s *Server) allocModel(
 		}
 	}()
 
-	var err error
-	s.model, err = model.New(mpath, params)
-	if err != nil {
-		return err
-	}
-
-	// TODO(jessegross): LoRA loading
-	if len(loraPath) > 0 {
-		return errors.New("loras are not yet implemented")
-	}
-
-	if s.model.Config().Cache == nil {
-		if parallel > 1 {
-			parallel = 1
-			slog.Warn("model does not support caching, disabling parallel processing")
-		}
-		if s.batchSize < kvSize {
-			s.batchSize = kvSize
-			slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
-		}
-	}
-
-	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
-	if err != nil {
-		return err
-	}
-
-	s.parallel = parallel
-	s.seqs = make([]*Sequence, s.parallel)
-	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
-
-	err = s.reserveWorstCaseGraph(true)
-	if err != nil {
-		return nil
-	}
-
-	return s.reserveWorstCaseGraph(false)
+reload:
+	for range 2 {
+		params := ml.BackendParams{
+			AllocMemory:    req.Operation != llm.LoadOperationFit,
+			NumThreads:     req.NumThreads,
+			GPULayers:      req.GPULayers,
+			FlashAttention: req.FlashAttention,
+		}
+
+		var err error
+		s.model, err = model.New(mpath, params)
+		if err != nil {
+			return err
+		}
+
+		// TODO(jessegross): LoRA loading
+		if len(req.LoraPath) > 0 {
+			return errors.New("loras are not yet implemented")
+		}
+
+		if params.FlashAttention == ml.FlashAttentionDisabled && ggml.KVCacheTypeIsQuantized(req.KvCacheType) {
+			slog.Warn("quantized kv cache requested but flash attention disabled", "type", req.KvCacheType)
+			req.KvCacheType = ""
+		}
+
+		if s.model.Config().Cache == nil {
+			if req.Parallel > 1 {
+				req.Parallel = 1
+				slog.Warn("model does not support caching, disabling parallel processing")
+			}
+			if req.BatchSize < req.KvSize {
+				req.BatchSize = req.KvSize
+				slog.Warn("model does not support caching, setting batch size to context length", "batch_size", req.KvSize)
+			}
+		}
+
+		s.cache, err = NewInputCache(s.model, req.KvCacheType, int32(req.KvSize), req.Parallel, req.BatchSize, req.MultiUserCache)
+		if err != nil {
+			return err
+		}
+
+		s.batchSize = req.BatchSize
+		s.parallel = req.Parallel
+		s.seqs = make([]*Sequence, s.parallel)
+		s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
+
+		for _, prompt := range []bool{true, false} {
+			if err := s.reserveWorstCaseGraph(prompt); err != nil {
+				if req.FlashAttention != ml.FlashAttentionDisabled {
+					slog.Warn("flash attention enabled but not supported by model")
+					req.FlashAttention = ml.FlashAttentionDisabled
+					s.closeModel()
+					continue reload
+				}
+
+				return err
+			}
+		}
+
+		if req.FlashAttention == ml.FlashAttentionAuto {
+			req.FlashAttention = ml.FlashAttentionEnabled
+		}
+
+		return nil
+	}
+
+	return errors.New("unable to allocate model")
 }
 
 // closeModel frees all memory associated with a model
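
The reload loop above allows at most two allocation attempts: if reserving the worst-case graph fails while flash attention is still on, the model is torn down, flash attention is disabled, and the labeled continue starts the second pass. A compact sketch of that retry shape (tryAllocate and its failure condition are made up for illustration; range-over-int needs Go 1.22+, which the diff itself relies on):

package main

import (
	"errors"
	"fmt"
)

// tryAllocate is a stand-in for building the model and reserving the
// worst-case graph; here it only succeeds once flash attention is off.
func tryAllocate(flashAttention bool) error {
	if flashAttention {
		return errors.New("flash attention not supported by model")
	}
	return nil
}

func allocate() error {
	flashAttention := true

reload:
	for range 2 {
		if err := tryAllocate(flashAttention); err != nil {
			if flashAttention {
				fmt.Println("disabling flash attention and retrying:", err)
				flashAttention = false
				continue reload // second and final attempt
			}
			return err
		}
		return nil // success on this attempt
	}

	return errors.New("unable to allocate model")
}

func main() {
	fmt.Println("allocate:", allocate())
}
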
@@ -1243,7 +1263,15 @@ func (s *Server) closeModel() {
 
 // loadModel loads the weights for a model. The memory must already
 // have been allocated with allocModel
-func (s *Server) loadModel() {
+func (s *Server) loadModel(req llm.LoadRequest) {
+	if req.FlashAttention != ml.FlashAttentionDisabled {
+		slog.Info("enabling flash attention")
+	}
+
+	if ggml.KVCacheTypeIsQuantized(req.KvCacheType) {
+		slog.Info("enabling kv cache quantization", "type", req.KvCacheType)
+	}
+
 	err := s.model.Backend().Load(context.TODO(),
 		func(progress float32) {
 			s.progress = progress
@@ -1279,7 +1307,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 
 	if req.Operation == llm.LoadOperationClose {
 		s.closeModel()
-		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
+		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil {
 			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		}
 		return
@@ -1288,27 +1316,16 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 	s.lastLoad.Operation = req.Operation
 	loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)
 
-	s.lastLoad = req
-
 	if loadModel {
 		s.closeModel()
 
-		params := ml.BackendParams{
-			AllocMemory:    req.Operation != llm.LoadOperationFit,
-			NumThreads:     req.NumThreads,
-			GPULayers:      req.GPULayers,
-			FlashAttention: req.FlashAttention,
-		}
-
-		s.batchSize = req.BatchSize
-
-		err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
+		err := s.allocModel(s.modelPath, &req)
 		if err != nil {
 			s.closeModel()
 
 			var noMem ml.ErrNoMem
 			if errors.As(err, &noMem) {
-				resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
+				resp := llm.LoadResponse{Success: false, Request: req, Memory: noMem.BackendMemory}
 				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 				}
@@ -1321,6 +1338,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
+	s.lastLoad = req
 	mem := s.model.Backend().BackendMemory()
 
 	switch req.Operation {
@@ -1332,10 +1350,10 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 
 	case llm.LoadOperationCommit:
 		s.status = llm.ServerStatusLoadingModel
-		go s.loadModel()
+		go s.loadModel(req)
 	}
 
-	resp := llm.LoadResponse{Success: true, Memory: mem}
+	resp := llm.LoadResponse{Success: true, Request: req, Memory: mem}
 	if err := json.NewEncoder(w).Encode(&resp); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 		return