feat(gallery): Speed up load times and clean gallery entries (#9211)

* feat: Rework VRAM estimation and use known_usecases in gallery

Signed-off-by: Richard Palethorpe <io@richiejp.com>
Assisted-by: Claude:claude-opus-4-7[1m] [Claude Code]

* chore(gallery): regenerate gallery index and add known_usecases to model entries

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe
2026-05-06 13:51:38 +01:00
committed by GitHub
parent 6d56bf98fe
commit 969005b2a1
47 changed files with 17089 additions and 5345 deletions

View File

@@ -3,94 +3,93 @@ package vram
import (
"context"
"sync"
"time"
)
const defaultEstimateCacheTTL = 15 * time.Minute
// galleryGenFunc returns the current gallery generation counter.
// When set, cache entries are invalidated when the generation changes.
// When nil (e.g., in tests or non-gallery contexts), entries never expire.
var galleryGenFunc func() uint64

// SetGalleryGenerationFunc wires the gallery generation counter into the
// VRAM caches. Call this once at application startup, before any estimator
// runs: the variable is read without synchronization by currentGeneration,
// so a write after readers have started would be a data race.
// NOTE(review): assumes single-threaded startup — confirm the call site.
func SetGalleryGenerationFunc(fn func() uint64) {
	galleryGenFunc = fn
}
// currentGeneration reports the gallery generation used to validate cache
// entries. When no generation function has been wired in (tests, non-gallery
// contexts) it returns 0, so cached entries effectively never expire.
func currentGeneration() uint64 {
	fn := galleryGenFunc
	if fn == nil {
		return 0
	}
	return fn()
}
type sizeCacheEntry struct {
size int64
err error
until time.Time
size int64
err error
generation uint64
}
// cachedSizeResolver memoizes ContentLength results per URI so repeated
// estimates do not re-issue size lookups for the same file.
type cachedSizeResolver struct {
	underlying SizeResolver  // resolver consulted on cache miss
	ttl        time.Duration // NOTE(review): looks unused now that entries are invalidated by gallery generation — confirm and remove
	mu         sync.Mutex    // guards cache
	cache      map[string]sizeCacheEntry
}
func (c *cachedSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
gen := currentGeneration()
c.mu.Lock()
e, ok := c.cache[uri]
c.mu.Unlock()
if ok && time.Now().Before(e.until) {
if ok && e.generation == gen {
return e.size, e.err
}
size, err := c.underlying.ContentLength(ctx, uri)
c.mu.Lock()
if c.cache == nil {
c.cache = make(map[string]sizeCacheEntry)
}
c.cache[uri] = sizeCacheEntry{size: size, err: err, until: time.Now().Add(c.ttl)}
c.cache[uri] = sizeCacheEntry{size: size, err: err, generation: gen}
c.mu.Unlock()
return size, err
}
type ggufCacheEntry struct {
meta *GGUFMeta
err error
until time.Time
meta *GGUFMeta
err error
generation uint64
}
// cachedGGUFReader memoizes ReadMetadata results per URI so repeated
// estimates do not re-fetch GGUF metadata for the same file.
type cachedGGUFReader struct {
	underlying GGUFMetadataReader // reader consulted on cache miss
	ttl        time.Duration      // NOTE(review): looks unused now that entries are invalidated by gallery generation — confirm and remove
	mu         sync.Mutex         // guards cache
	cache      map[string]ggufCacheEntry
}
func (c *cachedGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
gen := currentGeneration()
c.mu.Lock()
e, ok := c.cache[uri]
c.mu.Unlock()
if ok && time.Now().Before(e.until) {
if ok && e.generation == gen {
return e.meta, e.err
}
meta, err := c.underlying.ReadMetadata(ctx, uri)
c.mu.Lock()
if c.cache == nil {
c.cache = make(map[string]ggufCacheEntry)
}
c.cache[uri] = ggufCacheEntry{meta: meta, err: err, until: time.Now().Add(c.ttl)}
c.cache[uri] = ggufCacheEntry{meta: meta, err: err, generation: gen}
c.mu.Unlock()
return meta, err
}
// CachedSizeResolver returns a SizeResolver that caches ContentLength results by URI.
// NOTE(review): entries appear to be invalidated by gallery generation rather
// than TTL — confirm whether the ttl parameter still has any effect.
func CachedSizeResolver(underlying SizeResolver, ttl time.Duration) SizeResolver {
	return &cachedSizeResolver{underlying: underlying, ttl: ttl, cache: make(map[string]sizeCacheEntry)}
}
// CachedGGUFReader returns a GGUFMetadataReader that caches ReadMetadata results by URI.
// NOTE(review): entries appear to be invalidated by gallery generation rather
// than TTL — confirm whether the ttl parameter still has any effect.
func CachedGGUFReader(underlying GGUFMetadataReader, ttl time.Duration) GGUFMetadataReader {
	return &cachedGGUFReader{underlying: underlying, ttl: ttl, cache: make(map[string]ggufCacheEntry)}
}
// DefaultCachedSizeResolver returns a cached SizeResolver using the default
// implementation. A single shared cache is used so repeated size lookups for
// the same URI are avoided across requests; entries are invalidated when the
// gallery generation changes.
func DefaultCachedSizeResolver() SizeResolver {
	return defaultCachedSizeResolver
}
// DefaultCachedGGUFReader returns a cached GGUFMetadataReader using the
// default implementation. A single shared cache is used so repeated GGUF
// metadata fetches for the same URI are avoided across requests; entries are
// invalidated when the gallery generation changes.
func DefaultCachedGGUFReader() GGUFMetadataReader {
	return defaultCachedGGUFReader
}
var (
defaultCachedSizeResolver = CachedSizeResolver(defaultSizeResolver{}, defaultEstimateCacheTTL)
defaultCachedGGUFReader = CachedGGUFReader(defaultGGUFReader{}, defaultEstimateCacheTTL)
defaultCachedSizeResolver = &cachedSizeResolver{underlying: defaultSizeResolver{}, cache: make(map[string]sizeCacheEntry)}
defaultCachedGGUFReader = &cachedGGUFReader{underlying: defaultGGUFReader{}, cache: make(map[string]ggufCacheEntry)}
)

View File

@@ -23,17 +23,19 @@ func IsGGUF(nameOrURI string) bool {
return strings.ToLower(path.Ext(path.Base(nameOrURI))) == ".gguf"
}
func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) (EstimateResult, error) {
if opts.ContextLength == 0 {
opts.ContextLength = 8192
}
if opts.KVQuantBits == 0 {
opts.KVQuantBits = 16
}
// modelProfile captures the "fixed" properties of a model after I/O.
// Everything except context length is constant for a given model.
type modelProfile struct {
sizeBytes uint64 // total weight file size
ggufSize uint64 // GGUF file size (subset of sizeBytes)
meta *GGUFMeta // nil if no GGUF metadata available
}
var sizeBytes uint64
var ggufSize uint64
// resolveProfile does all I/O: iterates files, fetches sizes and GGUF metadata.
func resolveProfile(ctx context.Context, files []FileInput, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) modelProfile {
var p modelProfile
var firstGGUFURI string
for i := range files {
f := &files[i]
if !IsWeightFile(f.URI) {
@@ -47,23 +49,32 @@ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, size
continue
}
}
sizeBytes += uint64(sz)
p.sizeBytes += uint64(sz)
if IsGGUF(f.URI) {
ggufSize += uint64(sz)
p.ggufSize += uint64(sz)
if firstGGUFURI == "" {
firstGGUFURI = f.URI
}
}
}
sizeDisplay := FormatBytes(sizeBytes)
if p.ggufSize > 0 && ggufReader != nil && firstGGUFURI != "" {
p.meta, _ = ggufReader.ReadMetadata(ctx, firstGGUFURI)
}
var vramBytes uint64
if ggufSize > 0 {
var meta *GGUFMeta
if ggufReader != nil && firstGGUFURI != "" {
meta, _ = ggufReader.ReadMetadata(ctx, firstGGUFURI)
}
return p
}
// computeVRAM is pure arithmetic — no I/O. Returns VRAM bytes for a given
// model profile and context length.
func computeVRAM(p modelProfile, ctxLen uint32, opts EstimateOptions) uint64 {
kvQuantBits := opts.KVQuantBits
if kvQuantBits == 0 {
kvQuantBits = 16
}
if p.ggufSize > 0 {
meta := p.meta
if meta != nil && (meta.BlockCount > 0 || meta.EmbeddingLength > 0) {
nLayers := meta.BlockCount
if nLayers == 0 {
@@ -84,36 +95,29 @@ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, size
if gpuLayers <= 0 {
gpuLayers = int(nLayers)
}
ctxLen := opts.ContextLength
bKV := uint32(opts.KVQuantBits / 8)
bKV := uint32(kvQuantBits / 8)
if bKV == 0 {
bKV = 4
}
M_model := ggufSize
M_KV := uint64(bKV) * uint64(dModel) * uint64(nLayers) * uint64(ctxLen)
if headCountKV > 0 && meta.HeadCount > 0 {
M_KV = uint64(bKV) * uint64(dModel) * uint64(headCountKV) * uint64(ctxLen)
}
M_model := p.ggufSize
M_KV := uint64(bKV) * uint64(dModel) * uint64(headCountKV) * uint64(ctxLen)
P := M_model * 2
M_overhead := uint64(0.02*float64(P) + 0.15*1e9)
vramBytes = M_model + M_KV + M_overhead
vramBytes := M_model + M_KV + M_overhead
if nLayers > 0 && gpuLayers < int(nLayers) {
layerRatio := float64(gpuLayers) / float64(nLayers)
vramBytes = uint64(layerRatio*float64(M_model)) + M_KV + M_overhead
}
} else {
vramBytes = sizeOnlyVRAM(ggufSize, opts.ContextLength)
return vramBytes
}
} else if sizeBytes > 0 {
vramBytes = sizeOnlyVRAM(sizeBytes, opts.ContextLength)
return sizeOnlyVRAM(p.ggufSize, ctxLen)
}
return EstimateResult{
SizeBytes: sizeBytes,
SizeDisplay: sizeDisplay,
VRAMBytes: vramBytes,
VRAMDisplay: FormatBytes(vramBytes),
}, nil
if p.sizeBytes > 0 {
return sizeOnlyVRAM(p.sizeBytes, ctxLen)
}
return 0
}
func sizeOnlyVRAM(sizeOnDisk uint64, ctxLen uint32) uint64 {
@@ -125,6 +129,45 @@ func sizeOnlyVRAM(sizeOnDisk uint64, ctxLen uint32) uint64 {
return vram
}
// buildEstimates computes one VRAMAt entry per requested context size from a
// resolved model profile. Keys are the context sizes rendered as decimal
// strings (matching VRAMForContext's lookup).
func buildEstimates(p modelProfile, contextSizes []uint32, opts EstimateOptions) map[string]VRAMAt {
	estimates := make(map[string]VRAMAt, len(contextSizes))
	for _, size := range contextSizes {
		bytes := computeVRAM(p, size, opts)
		entry := VRAMAt{
			ContextLength: size,
			VRAMBytes:     bytes,
			VRAMDisplay:   FormatBytes(bytes),
		}
		estimates[fmt.Sprint(size)] = entry
	}
	return estimates
}
// EstimateMultiContext estimates model size and VRAM at multiple context
// sizes. All I/O happens once in resolveProfile; per-context VRAM is then
// pure arithmetic. A nil or empty contextSizes defaults to [8192].
func EstimateMultiContext(ctx context.Context, files []FileInput, contextSizes []uint32,
	opts EstimateOptions, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) (MultiContextEstimate, error) {
	if len(contextSizes) == 0 {
		contextSizes = []uint32{8192}
	}
	profile := resolveProfile(ctx, files, sizeResolver, ggufReader)
	out := MultiContextEstimate{
		SizeBytes:   profile.sizeBytes,
		SizeDisplay: FormatBytes(profile.sizeBytes),
		Estimates:   buildEstimates(profile, contextSizes, opts),
	}
	// Surface the model's maximum context length when GGUF metadata provides it.
	if meta := profile.meta; meta != nil && meta.MaximumContextLength > 0 {
		out.ModelMaxContext = meta.MaximumContextLength
	}
	return out, nil
}
// ParseSizeString parses a human-readable size string (e.g. "500MB", "14.5 GB", "2tb")
// into bytes. Supports B, KB, MB, GB, TB, PB (case-insensitive, space optional).
// Uses SI units (1 KB = 1000 B).
@@ -136,7 +179,6 @@ func ParseSizeString(s string) (uint64, error) {
s = strings.ToUpper(s)
// Find where the numeric part ends
i := 0
for i < len(s) && (s[i] == '.' || (s[i] >= '0' && s[i] <= '9')) {
i++
@@ -177,17 +219,6 @@ func ParseSizeString(s string) (uint64, error) {
return uint64(num * float64(multiplier)), nil
}
// EstimateFromSize builds an EstimateResult from a raw byte count.
func EstimateFromSize(sizeBytes uint64) EstimateResult {
vramBytes := sizeOnlyVRAM(sizeBytes, 8192)
return EstimateResult{
SizeBytes: sizeBytes,
SizeDisplay: FormatBytes(sizeBytes),
VRAMBytes: vramBytes,
VRAMDisplay: FormatBytes(vramBytes),
}
}
func FormatBytes(n uint64) string {
const unit = 1000
if n < unit {
@@ -216,24 +247,29 @@ func DefaultGGUFReader() GGUFMetadataReader {
}
// ModelEstimateInput describes the inputs for a unified VRAM/size estimation.
// The estimator cascades through available data: files -> size string -> HF repo -> zero.
// The estimator cascades through available data: files -> size string -> HF repo -> zero.
type ModelEstimateInput struct {
Files []FileInput // weight files with optional pre-known sizes
Size string // gallery hardcoded size (e.g. "14.5GB")
HFRepo string // HF repo ID or URL
Options EstimateOptions // context length, GPU layers, KV quant bits
Options EstimateOptions // GPU layers, KV quant bits
}
// EstimateModel provides a unified VRAM estimation entry point.
// EstimateModelMultiContext provides a unified VRAM estimation entry point
// that returns estimates at multiple context sizes.
// It tries (in order):
// 1. Direct file-based estimation (GGUF metadata or file size heuristic)
// 2. ParseSizeString from Size field
// 3. EstimateFromHFRepo
// 3. HuggingFace repo file listing
// 4. Zero result
func EstimateModel(ctx context.Context, input ModelEstimateInput) (EstimateResult, error) {
func EstimateModelMultiContext(ctx context.Context, input ModelEstimateInput, contextSizes []uint32) (MultiContextEstimate, error) {
if len(contextSizes) == 0 {
contextSizes = []uint32{8192}
}
// 1. Try direct file estimation
if len(input.Files) > 0 {
result, err := Estimate(ctx, input.Files, input.Options, DefaultCachedSizeResolver(), DefaultCachedGGUFReader())
result, err := EstimateMultiContext(ctx, input.Files, contextSizes, input.Options, DefaultCachedSizeResolver(), DefaultCachedGGUFReader())
if err != nil {
xlog.Debug("VRAM estimation from files failed", "error", err)
}
@@ -247,7 +283,11 @@ func EstimateModel(ctx context.Context, input ModelEstimateInput) (EstimateResul
if sizeBytes, err := ParseSizeString(input.Size); err != nil {
xlog.Debug("VRAM estimation from size string failed", "error", err, "size", input.Size)
} else if sizeBytes > 0 {
return EstimateFromSize(sizeBytes), nil
return MultiContextEstimate{
SizeBytes: sizeBytes,
SizeDisplay: FormatBytes(sizeBytes),
Estimates: buildEstimates(modelProfile{sizeBytes: sizeBytes}, contextSizes, EstimateOptions{}),
}, nil
}
}
@@ -257,15 +297,19 @@ func EstimateModel(ctx context.Context, input ModelEstimateInput) (EstimateResul
hfRepo = repoID
}
if hfRepo != "" {
result, err := EstimateFromHFRepo(ctx, hfRepo)
totalBytes, err := hfRepoWeightSize(ctx, hfRepo)
if err != nil {
xlog.Debug("VRAM estimation from HF repo failed", "error", err, "repo", hfRepo)
}
if err == nil && result.SizeBytes > 0 {
return result, nil
if err == nil && totalBytes > 0 {
return MultiContextEstimate{
SizeBytes: totalBytes,
SizeDisplay: FormatBytes(totalBytes),
Estimates: buildEstimates(modelProfile{sizeBytes: totalBytes}, contextSizes, EstimateOptions{}),
}, nil
}
}
// 4. No estimation possible
return EstimateResult{}, nil
return MultiContextEstimate{}, nil
}

View File

@@ -23,26 +23,25 @@ func (f fakeGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta
return f[uri], nil
}
var _ = Describe("Estimate", func() {
var _ = Describe("EstimateMultiContext", func() {
ctx := context.Background()
defaultCtx := []uint32{8192}
Describe("empty or non-GGUF inputs", func() {
It("returns zero size and vram for nil files", func() {
opts := EstimateOptions{ContextLength: 8192}
res, err := Estimate(ctx, nil, opts, nil, nil)
res, err := EstimateMultiContext(ctx, nil, defaultCtx, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(0)))
Expect(res.VRAMBytes).To(Equal(uint64(0)))
Expect(res.Estimates["8192"].VRAMBytes).To(Equal(uint64(0)))
Expect(res.SizeDisplay).To(Equal("0 B"))
})
It("counts only .gguf files and ignores other extensions", func() {
It("counts only weight files and ignores other extensions", func() {
files := []FileInput{
{URI: "http://a/model.gguf", Size: 1_000_000_000},
{URI: "http://a/readme.txt", Size: 100},
}
opts := EstimateOptions{ContextLength: 8192}
res, err := Estimate(ctx, files, opts, nil, nil)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(1_000_000_000)))
})
@@ -52,8 +51,7 @@ var _ = Describe("Estimate", func() {
{URI: "http://hf.co/model/model.safetensors", Size: 2_000_000_000},
{URI: "http://hf.co/model/model2.safetensors", Size: 3_000_000_000},
}
opts := EstimateOptions{ContextLength: 8192}
res, err := Estimate(ctx, files, opts, nil, nil)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(5_000_000_000)))
})
@@ -62,24 +60,22 @@ var _ = Describe("Estimate", func() {
Describe("GGUF size and resolver", func() {
It("uses size resolver when file size is not set", func() {
sizes := fakeSizeResolver{"http://example.com/model.gguf": 1_500_000_000}
opts := EstimateOptions{ContextLength: 8192}
files := []FileInput{{URI: "http://example.com/model.gguf"}}
res, err := Estimate(ctx, files, opts, sizes, nil)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, sizes, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(1_500_000_000)))
Expect(res.VRAMBytes).To(BeNumerically(">=", res.SizeBytes))
Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">=", res.SizeBytes))
Expect(res.SizeDisplay).To(Equal("1.5 GB"))
})
It("uses size-only VRAM formula when metadata is missing and size is large", func() {
sizes := fakeSizeResolver{"http://a/model.gguf": 10_000_000_000}
opts := EstimateOptions{ContextLength: 8192}
files := []FileInput{{URI: "http://a/model.gguf"}}
res, err := Estimate(ctx, files, opts, sizes, nil)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, sizes, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.VRAMBytes).To(BeNumerically(">", 10_000_000_000))
Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">", 10_000_000_000))
})
It("sums size for multiple GGUF shards", func() {
@@ -87,18 +83,16 @@ var _ = Describe("Estimate", func() {
{URI: "http://a/shard1.gguf", Size: 10_000_000_000},
{URI: "http://a/shard2.gguf", Size: 5_000_000_000},
}
opts := EstimateOptions{ContextLength: 8192}
res, err := Estimate(ctx, files, opts, nil, nil)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000)))
})
It("formats size display correctly", func() {
files := []FileInput{{URI: "http://a/model.gguf", Size: 2_500_000_000}}
opts := EstimateOptions{ContextLength: 8192}
res, err := Estimate(ctx, files, opts, nil, nil)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeDisplay).To(Equal("2.5 GB"))
})
@@ -108,24 +102,94 @@ var _ = Describe("Estimate", func() {
It("uses metadata for VRAM when reader returns meta and partial offload", func() {
meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096}
reader := fakeGGUFReader{"http://a/model.gguf": meta}
opts := EstimateOptions{ContextLength: 8192, GPULayers: 20}
opts := EstimateOptions{GPULayers: 20}
files := []FileInput{{URI: "http://a/model.gguf", Size: 8_000_000_000}}
res, err := Estimate(ctx, files, opts, nil, reader)
res, err := EstimateMultiContext(ctx, files, defaultCtx, opts, nil, reader)
Expect(err).ToNot(HaveOccurred())
Expect(res.VRAMBytes).To(BeNumerically(">", 0))
Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">", 0))
})
It("uses metadata head counts for KV and yields vram > size", func() {
files := []FileInput{{URI: "http://a/model.gguf", Size: 15_000_000_000}}
meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, HeadCount: 32, HeadCountKV: 8}
reader := fakeGGUFReader{"http://a/model.gguf": meta}
opts := EstimateOptions{ContextLength: 8192}
res, err := Estimate(ctx, files, opts, nil, reader)
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, reader)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000)))
Expect(res.VRAMBytes).To(BeNumerically(">", res.SizeBytes))
Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">", res.SizeBytes))
})
It("populates ModelMaxContext from GGUF metadata", func() {
meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, MaximumContextLength: 131072}
reader := fakeGGUFReader{"http://a/model.gguf": meta}
files := []FileInput{{URI: "http://a/model.gguf", Size: 8_000_000_000}}
res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, reader)
Expect(err).ToNot(HaveOccurred())
Expect(res.ModelMaxContext).To(Equal(uint64(131072)))
})
})
Describe("multi-context behavior", func() {
It("returns estimates for all requested context sizes", func() {
files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}}
sizes := []uint32{8192, 32768, 131072}
res, err := EstimateMultiContext(ctx, files, sizes, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.Estimates).To(HaveLen(3))
Expect(res.Estimates).To(HaveKey("8192"))
Expect(res.Estimates).To(HaveKey("32768"))
Expect(res.Estimates).To(HaveKey("131072"))
})
It("VRAM increases monotonically with context size", func() {
files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}}
meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, HeadCount: 32, HeadCountKV: 8}
reader := fakeGGUFReader{"http://a/model.gguf": meta}
sizes := []uint32{8192, 16384, 32768, 65536, 131072, 262144}
res, err := EstimateMultiContext(ctx, files, sizes, EstimateOptions{}, nil, reader)
Expect(err).ToNot(HaveOccurred())
prev := uint64(0)
for _, sz := range sizes {
v := res.VRAMForContext(sz)
Expect(v).To(BeNumerically(">", prev), "VRAM should increase at context %d", sz)
prev = v
}
})
It("size is constant across context sizes", func() {
files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}}
sizes := []uint32{8192, 32768}
res, err := EstimateMultiContext(ctx, files, sizes, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.SizeBytes).To(Equal(uint64(4_000_000_000)))
})
It("defaults to [8192] when contextSizes is empty", func() {
files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}}
res, err := EstimateMultiContext(ctx, files, nil, EstimateOptions{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(res.Estimates).To(HaveLen(1))
Expect(res.Estimates).To(HaveKey("8192"))
})
})
Describe("VRAMForContext helper", func() {
It("returns 0 for missing context size", func() {
res := MultiContextEstimate{
Estimates: map[string]VRAMAt{
"8192": {VRAMBytes: 5000},
},
}
Expect(res.VRAMForContext(99999)).To(Equal(uint64(0)))
Expect(res.VRAMForContext(8192)).To(Equal(uint64(5000)))
})
})
})

View File

@@ -4,7 +4,6 @@ import (
"context"
"strings"
"sync"
"time"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)
@@ -15,13 +14,11 @@ var (
)
type hfSizeCacheEntry struct {
result EstimateResult
err error
expiresAt time.Time
totalBytes uint64
err error
generation uint64
}
const hfSizeCacheTTL = 15 * time.Minute
// ExtractHFRepoID extracts a HuggingFace repo ID from a string.
// It handles both short form ("org/model") and full URL form
// ("https://huggingface.co/org/model", "huggingface.co/org/model").
@@ -62,30 +59,31 @@ func ExtractHFRepoID(s string) (string, bool) {
return "", false
}
// EstimateFromHFRepo estimates model size by querying the HuggingFace API for file listings.
// Results are cached for 15 minutes.
func EstimateFromHFRepo(ctx context.Context, repoID string) (EstimateResult, error) {
// hfRepoWeightSize returns the total weight file size for a HuggingFace repo.
// Results are cached and invalidated when the gallery generation changes.
func hfRepoWeightSize(ctx context.Context, repoID string) (uint64, error) {
gen := currentGeneration()
hfSizeCacheMu.Lock()
if entry, ok := hfSizeCacheData[repoID]; ok && time.Now().Before(entry.expiresAt) {
if entry, ok := hfSizeCacheData[repoID]; ok && entry.generation == gen {
hfSizeCacheMu.Unlock()
return entry.result, entry.err
return entry.totalBytes, entry.err
}
hfSizeCacheMu.Unlock()
result, err := estimateFromHFRepoUncached(ctx, repoID)
totalBytes, err := hfRepoWeightSizeUncached(ctx, repoID)
hfSizeCacheMu.Lock()
hfSizeCacheData[repoID] = hfSizeCacheEntry{
result: result,
err: err,
expiresAt: time.Now().Add(hfSizeCacheTTL),
totalBytes: totalBytes,
err: err,
generation: gen,
}
hfSizeCacheMu.Unlock()
return result, err
return totalBytes, err
}
func estimateFromHFRepoUncached(ctx context.Context, repoID string) (EstimateResult, error) {
func hfRepoWeightSizeUncached(ctx context.Context, repoID string) (uint64, error) {
client := hfapi.NewClient()
type listResult struct {
@@ -100,17 +98,17 @@ func estimateFromHFRepoUncached(ctx context.Context, repoID string) (EstimateRes
select {
case <-ctx.Done():
return EstimateResult{}, ctx.Err()
return 0, ctx.Err()
case res := <-ch:
if res.err != nil {
return EstimateResult{}, res.err
return 0, res.err
}
return estimateFromFileInfos(res.files), nil
return sumWeightFileBytes(res.files), nil
}
}
func estimateFromFileInfos(files []hfapi.FileInfo) EstimateResult {
var totalSize int64
func sumWeightFileBytes(files []hfapi.FileInfo) uint64 {
var total int64
for _, f := range files {
if f.Type != "file" {
continue
@@ -128,20 +126,10 @@ func estimateFromFileInfos(files []hfapi.FileInfo) EstimateResult {
if f.LFS != nil && f.LFS.Size > 0 {
size = f.LFS.Size
}
totalSize += size
total += size
}
if totalSize <= 0 {
return EstimateResult{}
}
sizeBytes := uint64(totalSize)
vramBytes := sizeOnlyVRAM(sizeBytes, 8192)
return EstimateResult{
SizeBytes: sizeBytes,
SizeDisplay: FormatBytes(sizeBytes),
VRAMBytes: vramBytes,
VRAMDisplay: FormatBytes(vramBytes),
if total < 0 {
return 0
}
return uint64(total)
}

View File

@@ -1,6 +1,9 @@
package vram
import "context"
import (
"context"
"fmt"
)
// FileInput represents a single model file for estimation (URI and optional pre-known size).
type FileInput struct {
@@ -28,16 +31,45 @@ type GGUFMetadataReader interface {
}
// EstimateOptions configures VRAM/size estimation.
// GPULayers and KVQuantBits apply uniformly across all context sizes.
type EstimateOptions struct {
ContextLength uint32
GPULayers int
KVQuantBits int
GPULayers int
KVQuantBits int
}
// EstimateResult holds estimated download size and VRAM with display strings.
type EstimateResult struct {
SizeBytes uint64 `json:"sizeBytes"` // total model weight size in bytes
SizeDisplay string `json:"sizeDisplay"` // human-readable size (e.g. "4.2 GB")
VRAMBytes uint64 `json:"vramBytes"` // estimated VRAM usage in bytes
VRAMDisplay string `json:"vramDisplay"` // human-readable VRAM (e.g. "6.1 GB")
// VRAMAt holds the VRAM estimate at a specific context size.
type VRAMAt struct {
	ContextLength uint32 `json:"contextLength"` // context window this estimate was computed for
	VRAMBytes     uint64 `json:"vramBytes"`     // estimated VRAM usage in bytes
	VRAMDisplay   string `json:"vramDisplay"`   // human-readable VRAM (e.g. "6.1 GB")
}
// EstimateResult is a flat single-context view of an estimate, suitable for
// the REST /api/models/vram-estimate response and the MCP vram_estimate tool.
// It is the legacy shape the LLM and HTTP clients expect (size_bytes /
// size_display / vram_bytes / vram_display).
type EstimateResult struct {
	SizeBytes     uint64 `json:"size_bytes"`               // total model weight size in bytes
	SizeDisplay   string `json:"size_display"`             // human-readable size (e.g. "4.2 GB")
	ContextLength uint32 `json:"context_length,omitempty"` // context size the VRAM figure applies to; omitted when zero
	VRAMBytes     uint64 `json:"vram_bytes"`               // estimated VRAM usage in bytes
	VRAMDisplay   string `json:"vram_display"`             // human-readable VRAM (e.g. "6.1 GB")
}
// MultiContextEstimate holds VRAM estimates for one or more context sizes,
// computed from a single metadata fetch.
type MultiContextEstimate struct {
	SizeBytes       uint64            `json:"sizeBytes"`                 // total model weight size in bytes
	SizeDisplay     string            `json:"sizeDisplay"`               // human-readable size (e.g. "4.2 GB")
	Estimates       map[string]VRAMAt `json:"estimates"`                 // keys: context size as string
	ModelMaxContext uint64            `json:"modelMaxContext,omitempty"` // from GGUF metadata; 0 when unknown
}
// VRAMForContext is a convenience method that returns the VRAMBytes for a
// specific context size, or 0 when no estimate exists for that size.
func (m MultiContextEstimate) VRAMForContext(ctxLen uint32) uint64 {
	entry, ok := m.Estimates[fmt.Sprint(ctxLen)]
	if !ok {
		return 0
	}
	return entry.VRAMBytes
}