Files
LocalAI/pkg/vram/hf_estimate.go
Richard Palethorpe 969005b2a1 feat(gallery): Speed up load times and clean gallery entries (#9211)
* feat: Rework VRAM estimation and use known_usecases in gallery

Signed-off-by: Richard Palethorpe <io@richiejp.com>
Assisted-by: Claude:claude-opus-4-7[1m] [Claude Code]

* chore(gallery): regenerate gallery index and add known_usecases to model entries

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-06 14:51:38 +02:00

136 lines
3.0 KiB
Go

package vram
import (
"context"
"strings"
"sync"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)
// Package-level cache of per-repo weight sizes used by hfRepoWeightSize.
var (
// hfSizeCacheMu is held around each access to hfSizeCacheData in
// hfRepoWeightSize.
hfSizeCacheMu sync.Mutex
// hfSizeCacheData maps a HuggingFace repo ID to the most recently stored
// lookup result for that repo.
hfSizeCacheData = make(map[string]hfSizeCacheEntry)
)
// hfSizeCacheEntry is one stored result of a repo weight-size lookup.
type hfSizeCacheEntry struct {
// totalBytes is the size value returned by the lookup.
totalBytes uint64
// err is the error returned by the lookup, if any; it is handed back to
// callers on a cache hit together with totalBytes.
err error
// generation is the currentGeneration() value captured when the lookup
// started; the entry is only reused while that value still matches.
generation uint64
}
// ExtractHFRepoID extracts a HuggingFace repo ID from a string.
// It handles both short form ("org/model") and full URL form
// ("https://huggingface.co/org/model", "huggingface.co/org/model").
//
// The scheme/host prefix of the URL form is matched case-insensitively; the
// case of the org and model segments is preserved. For the URL form, any
// query string ("?...") or fragment ("#...") and any path segments beyond
// org/model are discarded.
//
// Returns the repo ID and true if found, or empty string and false otherwise.
func ExtractHFRepoID(s string) (string, bool) {
	s = strings.TrimSpace(s)
	if s == "" {
		return "", false
	}

	// Handle full URL form: https://huggingface.co/org/model or huggingface.co/org/model
	for _, prefix := range []string{
		"https://huggingface.co/",
		"http://huggingface.co/",
		"huggingface.co/",
	} {
		// Compare against the leading bytes of s itself rather than a
		// lower-cased copy: lower-casing can change the byte length of
		// non-ASCII input, which would make s[len(prefix):] cut at the
		// wrong offset.
		if len(s) < len(prefix) || !strings.EqualFold(s[:len(prefix)], prefix) {
			continue
		}
		rest := s[len(prefix):]
		// Query strings and fragments are not part of the repo path.
		if i := strings.IndexAny(rest, "?#"); i >= 0 {
			rest = rest[:i]
		}
		// Strip trailing slashes and path segments beyond org/model.
		rest = strings.TrimRight(rest, "/")
		parts := strings.SplitN(rest, "/", 3)
		if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
			return parts[0] + "/" + parts[1], true
		}
		return "", false
	}

	// Handle short form: org/model. Anything that looks like another URL or
	// contains a space is rejected.
	if strings.Contains(s, "://") || strings.Contains(s, " ") {
		return "", false
	}
	parts := strings.Split(s, "/")
	if len(parts) == 2 && parts[0] != "" && parts[1] != "" {
		return s, true
	}
	return "", false
}
// hfRepoWeightSize returns the total weight file size for a HuggingFace repo.
//
// Results, including lookup errors, are cached per repo ID and reused for as
// long as the value returned by currentGeneration is unchanged.
//
// A failure observed while ctx is already done is returned to the caller but
// not cached: cancellation or a deadline belongs to this particular call, and
// storing it would hand the same error to later callers whose contexts are
// still live.
func hfRepoWeightSize(ctx context.Context, repoID string) (uint64, error) {
	gen := currentGeneration()

	hfSizeCacheMu.Lock()
	entry, ok := hfSizeCacheData[repoID]
	hfSizeCacheMu.Unlock()
	if ok && entry.generation == gen {
		return entry.totalBytes, entry.err
	}

	// The mutex is deliberately not held across the uncached lookup so that
	// other repo IDs can be served from the cache in the meantime.
	totalBytes, err := hfRepoWeightSizeUncached(ctx, repoID)
	if err != nil && ctx.Err() != nil {
		return totalBytes, err
	}

	hfSizeCacheMu.Lock()
	hfSizeCacheData[repoID] = hfSizeCacheEntry{
		totalBytes: totalBytes,
		err:        err,
		generation: gen,
	}
	hfSizeCacheMu.Unlock()
	return totalBytes, err
}
// hfRepoWeightSizeUncached lists the files of repoID through a fresh hfapi
// client and returns the summed size of its weight files, bypassing the cache.
//
// ListFiles is called without a context, so the call runs in its own
// goroutine while this function waits for either the listing or ctx to
// finish. If ctx finishes first, ctx.Err() is returned and the listing result
// is dropped.
func hfRepoWeightSizeUncached(ctx context.Context, repoID string) (uint64, error) {
	type listing struct {
		files []hfapi.FileInfo
		err   error
	}

	client := hfapi.NewClient()

	// Capacity 1 lets the worker deliver its result and exit even when nobody
	// is receiving any more because we already returned on ctx.Done().
	results := make(chan listing, 1)
	go func() {
		var out listing
		out.files, out.err = client.ListFiles(repoID)
		results <- out
	}()

	select {
	case got := <-results:
		if got.err == nil {
			return sumWeightFileBytes(got.files), nil
		}
		return 0, got.err
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}
// sumWeightFileBytes adds up the sizes of the entries in files that are
// regular files (Type == "file") and whose lower-cased extension, taken from
// the last "." in the path, is present in weightExts.
//
// For each counted entry the LFS size is used when LFS metadata is present
// and reports a positive size; otherwise the entry's own Size is used. A
// negative grand total is reported as 0.
func sumWeightFileBytes(files []hfapi.FileInfo) uint64 {
	var sum int64
	for _, entry := range files {
		if entry.Type != "file" {
			continue
		}

		// Entries without any "." have no extension to match against.
		dot := strings.LastIndex(entry.Path, ".")
		if dot < 0 {
			continue
		}
		if suffix := strings.ToLower(entry.Path[dot:]); !weightExts[suffix] {
			continue
		}

		n := entry.Size
		if lfs := entry.LFS; lfs != nil && lfs.Size > 0 {
			n = lfs.Size
		}
		sum += n
	}

	if sum < 0 {
		return 0
	}
	return uint64(sum)
}