@@ -614,10 +636,10 @@ function modelsGallery() {
totalPages: 1,
availableModels: 0,
installedModels: 0,
- storageSize: 0,
ramTotal: 0,
ramUsed: 0,
ramUsagePercent: 0,
+ totalMemory: 0,
selectedModel: null,
jobProgress: {},
notifications: [],
@@ -626,10 +648,21 @@ function modelsGallery() {
+ // Runs fetchModels and fetchResources once, then calls pollJobs on a fixed interval.
init() {
this.fetchModels();
+ this.fetchResources();
// Poll for job progress every 600ms
setInterval(() => this.pollJobs(), 600);
},
+ async fetchResources() {
+ try {
+ const response = await fetch('/api/resources');
+ if (response.ok) {
+ const data = await response.json();
+ this.totalMemory = data.aggregate?.total_memory || 0;
+ }
+ } catch (e) {}
+ },
+
addNotification(message, type = 'error') {
const id = Date.now();
this.notifications.push({ id, message, type });
@@ -663,7 +696,6 @@ function modelsGallery() {
this.totalPages = data.totalPages || 1;
this.availableModels = data.availableModels || 0;
this.installedModels = data.installedModels || 0;
- this.storageSize = data.storageSize || 0;
this.ramTotal = data.ramTotal || 0;
this.ramUsed = data.ramUsed || 0;
this.ramUsagePercent = data.ramUsagePercent || 0;
diff --git a/core/schema/localai.go b/core/schema/localai.go
index 7ccf2bb32..ccf3e6e10 100644
--- a/core/schema/localai.go
+++ b/core/schema/localai.go
@@ -24,6 +24,11 @@ type BackendMonitorResponse struct {
type GalleryResponse struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
+
+ // Optional VRAM and size estimates, each given as a byte count plus a
+ // display string. omitempty leaves unset (zero/empty) values out of the JSON.
+ EstimatedVRAMBytes uint64 `json:"estimated_vram_bytes,omitempty"`
+ EstimatedVRAMDisplay string `json:"estimated_vram_display,omitempty"`
+ EstimatedSizeBytes uint64 `json:"estimated_size_bytes,omitempty"`
+ EstimatedSizeDisplay string `json:"estimated_size_display,omitempty"`
}
type VideoRequest struct {
diff --git a/docs/content/features/model-gallery.md b/docs/content/features/model-gallery.md
index 96aaf01d0..a43ff2941 100644
--- a/docs/content/features/model-gallery.md
+++ b/docs/content/features/model-gallery.md
@@ -31,6 +31,15 @@ GPT and text generation models might have a license which is not permissive for
Navigate the WebUI interface in the "Models" section from the navbar at the top. Here you can find a list of models that can be installed, and you can install them by clicking the "Install" button.
+## VRAM and download size estimates
+
+When browsing the gallery or importing a model by URI, LocalAI can show **estimated download size** and **estimated VRAM** for models.
+
+- **Where they appear**: In the model gallery table (Size / VRAM column), in the model detail modal, and after starting an import from URI (in the success message).
+- **How they are computed**: GGUF models use file size (HTTP HEAD or local stat) and optional GGUF metadata (HTTP Range) for KV cache and overhead; other formats use Hugging Face file sizes and optional config when available. If metadata is unavailable, a size-only heuristic is used.
+- **Hardware fit indicator**: When your system reports GPU or RAM capacity, the gallery shows whether the estimated VRAM fits (green) or may not fit (red) using a 95% headroom rule.
+ - **Limitations**: Estimates are best-effort. They may be missing when the remote server does not support HTTP HEAD or Range requests, or when the request times out.
+
## Add other galleries
You can add other galleries by:
diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go
index 5baf04dd5..5db5de4a1 100644
--- a/pkg/downloader/uri.go
+++ b/pkg/downloader/uri.go
@@ -275,6 +275,68 @@ func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
return resp.Header.Get("Accept-Ranges") == "bytes", nil
}
+ // ContentLength returns the size in bytes of the resource at the URI.
+ //
+ // For URIs starting with LocalPrefix it uses os.Stat on the resolved path.
+ // For HTTP/HTTPS it sends a HEAD request and returns its Content-Length. When
+ // that header is missing and the server advertises "Accept-Ranges: bytes", it
+ // falls back to a GET with "Range: bytes=0-0" and reads the total length from
+ // the Content-Range header. If the server ignores the range and answers 200
+ // without Content-Range, the Content-Length of that response is used instead.
+ //
+ // Any other scheme yields an error. Both HTTP requests are bound to ctx, which
+ // is the only timeout mechanism since http.DefaultClient has none.
+ func (uri URI) ContentLength(ctx context.Context) (int64, error) {
+ 	urlStr := uri.ResolveURL()
+ 	if strings.HasPrefix(string(uri), LocalPrefix) {
+ 		info, err := os.Stat(urlStr)
+ 		if err != nil {
+ 			return 0, err
+ 		}
+ 		return info.Size(), nil
+ 	}
+ 	if !uri.LooksLikeHTTPURL() {
+ 		return 0, fmt.Errorf("unsupported URI scheme for ContentLength: %s", string(uri))
+ 	}
+
+ 	headReq, err := http.NewRequestWithContext(ctx, http.MethodHead, urlStr, nil)
+ 	if err != nil {
+ 		return 0, err
+ 	}
+ 	headResp, err := http.DefaultClient.Do(headReq)
+ 	if err != nil {
+ 		return 0, err
+ 	}
+ 	defer headResp.Body.Close()
+ 	if headResp.StatusCode >= 400 {
+ 		return 0, fmt.Errorf("HEAD %s: status %d", urlStr, headResp.StatusCode)
+ 	}
+ 	if headResp.ContentLength >= 0 {
+ 		return headResp.ContentLength, nil
+ 	}
+ 	if headResp.Header.Get("Accept-Ranges") != "bytes" {
+ 		return 0, fmt.Errorf("HEAD %s: no Content-Length and server does not support Range", urlStr)
+ 	}
+
+ 	// Ask for a single byte: the Content-Range of the reply carries the total size.
+ 	rangeReq, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
+ 	if err != nil {
+ 		return 0, err
+ 	}
+ 	rangeReq.Header.Set("Range", "bytes=0-0")
+ 	rangeResp, err := http.DefaultClient.Do(rangeReq)
+ 	if err != nil {
+ 		return 0, err
+ 	}
+ 	defer rangeResp.Body.Close()
+ 	if rangeResp.StatusCode != http.StatusPartialContent && rangeResp.StatusCode != http.StatusOK {
+ 		return 0, fmt.Errorf("Range request %s: status %d", urlStr, rangeResp.StatusCode)
+ 	}
+ 	cr := rangeResp.Header.Get("Content-Range")
+ 	// Content-Range: bytes 0-0/12345
+ 	if cr == "" {
+ 		// A 200 reply means the Range header was ignored and the whole resource
+ 		// is being sent, so its Content-Length (when known) is the full size.
+ 		if rangeResp.StatusCode == http.StatusOK && rangeResp.ContentLength >= 0 {
+ 			return rangeResp.ContentLength, nil
+ 		}
+ 		return 0, fmt.Errorf("Range request %s: no Content-Range header", urlStr)
+ 	}
+ 	parts := strings.Split(cr, "/")
+ 	if len(parts) != 2 {
+ 		return 0, fmt.Errorf("invalid Content-Range: %s", cr)
+ 	}
+ 	// The total may legally be "*" (unknown); ParseInt rejects it and we report it.
+ 	size, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
+ 	if err != nil || size < 0 {
+ 		return 0, fmt.Errorf("invalid Content-Range total length: %s", parts[1])
+ 	}
+ 	return size, nil
+ }
+
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
return uri.DownloadFileWithContext(context.Background(), filePath, sha, fileN, total, downloadStatus)
}
diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go
index 571869077..bd895defc 100644
--- a/pkg/downloader/uri_test.go
+++ b/pkg/downloader/uri_test.go
@@ -1,12 +1,15 @@
package downloader_test
import (
+ "context"
"crypto/rand"
"crypto/sha256"
+ "errors"
"fmt"
"net/http"
"net/http/httptest"
"os"
+ "path/filepath"
"regexp"
"strconv"
@@ -48,6 +51,86 @@ var _ = Describe("Gallery API tests", func() {
})
})
+ // Specs for URI.ContentLength: local files via file:// and HTTP servers that
+ // either report Content-Length on HEAD or only expose the size through a
+ // ranged GET. Handlers that assert run on the server's goroutine, so they
+ // defer GinkgoRecover to report failures through Ginkgo.
+ var _ = Describe("ContentLength", func() {
+ 	Context("local file", func() {
+ 		It("returns file size for existing file", func() {
+ 			dir, err := os.MkdirTemp("", "contentlength-*")
+ 			Expect(err).ToNot(HaveOccurred())
+ 			defer os.RemoveAll(dir)
+ 			fpath := filepath.Join(dir, "model.gguf")
+ 			err = os.WriteFile(fpath, make([]byte, 1234), 0644)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			uri := URI("file://" + fpath)
+ 			ctx := context.Background()
+ 			size, err := uri.ContentLength(ctx)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(size).To(Equal(int64(1234)))
+ 		})
+ 		It("returns error for missing file", func() {
+ 			uri := URI("file:///nonexistent/path/model.gguf")
+ 			ctx := context.Background()
+ 			_, err := uri.ContentLength(ctx)
+ 			Expect(err).To(HaveOccurred())
+ 		})
+ 	})
+ 	Context("HTTP", func() {
+ 		It("returns Content-Length when present", func() {
+ 			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ 				defer GinkgoRecover()
+ 				Expect(r.Method).To(Equal("HEAD"))
+ 				w.Header().Set("Content-Length", "1000")
+ 				w.WriteHeader(http.StatusOK)
+ 			}))
+ 			defer server.Close()
+ 			uri := URI(server.URL)
+ 			ctx := context.Background()
+ 			size, err := uri.ContentLength(ctx)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(size).To(Equal(int64(1000)))
+ 		})
+ 		It("returns error on 404", func() {
+ 			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ 				w.WriteHeader(http.StatusNotFound)
+ 			}))
+ 			defer server.Close()
+ 			uri := URI(server.URL)
+ 			ctx := context.Background()
+ 			_, err := uri.ContentLength(ctx)
+ 			Expect(err).To(HaveOccurred())
+ 		})
+ 		It("uses Range when Content-Length missing and Accept-Ranges bytes", func() {
+ 			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ 				defer GinkgoRecover()
+ 				if r.Method == "HEAD" {
+ 					w.Header().Set("Accept-Ranges", "bytes")
+ 					w.WriteHeader(http.StatusOK)
+ 					return
+ 				}
+ 				Expect(r.Header.Get("Range")).To(Equal("bytes=0-0"))
+ 				w.Header().Set("Content-Range", "bytes 0-0/5000")
+ 				w.WriteHeader(http.StatusPartialContent)
+ 			}))
+ 			defer server.Close()
+ 			uri := URI(server.URL)
+ 			ctx := context.Background()
+ 			size, err := uri.ContentLength(ctx)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(size).To(Equal(int64(5000)))
+ 		})
+ 		It("respects context cancellation", func() {
+ 			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ 				w.Header().Set("Content-Length", "1000")
+ 				w.WriteHeader(http.StatusOK)
+ 			}))
+ 			defer server.Close()
+ 			// Cancel before the call so the request fails without reaching the server.
+ 			ctx, cancel := context.WithCancel(context.Background())
+ 			cancel()
+ 			uri := URI(server.URL)
+ 			_, err := uri.ContentLength(ctx)
+ 			Expect(err).To(HaveOccurred())
+ 			Expect(errors.Is(err, context.Canceled)).To(BeTrue())
+ 		})
+ 	})
+ })
+
type RangeHeaderError struct {
msg string
}
diff --git a/pkg/vram/cache.go b/pkg/vram/cache.go
new file mode 100644
index 000000000..be2cdbd4f
--- /dev/null
+++ b/pkg/vram/cache.go
@@ -0,0 +1,96 @@
+package vram
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+ // defaultEstimateCacheTTL is the entry lifetime used by the package-level default caches.
+ const defaultEstimateCacheTTL = 15 * time.Minute
+
+ // sizeCacheEntry is one remembered ContentLength outcome (size or error)
+ // together with the instant after which it must be refreshed.
+ type sizeCacheEntry struct {
+ 	size  int64
+ 	err   error
+ 	until time.Time
+ }
+
+ // cachedSizeResolver memoizes the results of an underlying SizeResolver per
+ // URI for ttl. mu guards cache.
+ type cachedSizeResolver struct {
+ 	underlying SizeResolver
+ 	ttl        time.Duration
+ 	mu         sync.Mutex
+ 	cache      map[string]sizeCacheEntry
+ }
+
+ // ContentLength returns the cached outcome for uri while it is still fresh;
+ // otherwise it queries the underlying resolver and stores the outcome for ttl.
+ //
+ // Errors are cached too, so a broken URI is not retried on every call. The one
+ // exception is a failure observed after ctx has ended: a cancelled or timed-out
+ // request says nothing about the URI and must not poison the shared cache for
+ // later callers, so it is returned without being stored.
+ //
+ // The lock is not held during the underlying call, so concurrent misses for the
+ // same URI may each query the underlying resolver.
+ func (c *cachedSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
+ 	c.mu.Lock()
+ 	e, ok := c.cache[uri]
+ 	c.mu.Unlock()
+ 	if ok && time.Now().Before(e.until) {
+ 		return e.size, e.err
+ 	}
+
+ 	size, err := c.underlying.ContentLength(ctx, uri)
+ 	if err != nil && ctx.Err() != nil {
+ 		return size, err
+ 	}
+
+ 	c.mu.Lock()
+ 	defer c.mu.Unlock()
+ 	if c.cache == nil {
+ 		// Supports a zero-value resolver that was not built by the constructor.
+ 		c.cache = make(map[string]sizeCacheEntry)
+ 	}
+ 	c.cache[uri] = sizeCacheEntry{size: size, err: err, until: time.Now().Add(c.ttl)}
+ 	return size, err
+ }
+
+ // ggufCacheEntry is one remembered ReadMetadata outcome (metadata or error)
+ // together with the instant after which it must be refreshed.
+ type ggufCacheEntry struct {
+ 	meta  *GGUFMeta
+ 	err   error
+ 	until time.Time
+ }
+
+ // cachedGGUFReader memoizes the results of an underlying GGUFMetadataReader
+ // per URI for ttl. mu guards cache.
+ type cachedGGUFReader struct {
+ 	underlying GGUFMetadataReader
+ 	ttl        time.Duration
+ 	mu         sync.Mutex
+ 	cache      map[string]ggufCacheEntry
+ }
+
+ // ReadMetadata returns the cached outcome for uri while it is still fresh;
+ // otherwise it queries the underlying reader and stores the outcome for ttl.
+ //
+ // Errors are cached too, so an unreadable URI is not retried on every call.
+ // The one exception is a failure observed after ctx has ended: a cancelled or
+ // timed-out request says nothing about the URI and must not poison the shared
+ // cache for later callers, so it is returned without being stored.
+ //
+ // The returned *GGUFMeta is shared between callers and should be treated as
+ // read-only. The lock is not held during the underlying call, so concurrent
+ // misses for the same URI may each query the underlying reader.
+ func (c *cachedGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
+ 	c.mu.Lock()
+ 	e, ok := c.cache[uri]
+ 	c.mu.Unlock()
+ 	if ok && time.Now().Before(e.until) {
+ 		return e.meta, e.err
+ 	}
+
+ 	meta, err := c.underlying.ReadMetadata(ctx, uri)
+ 	if err != nil && ctx.Err() != nil {
+ 		return meta, err
+ 	}
+
+ 	c.mu.Lock()
+ 	defer c.mu.Unlock()
+ 	if c.cache == nil {
+ 		// Supports a zero-value reader that was not built by the constructor.
+ 		c.cache = make(map[string]ggufCacheEntry)
+ 	}
+ 	c.cache[uri] = ggufCacheEntry{meta: meta, err: err, until: time.Now().Add(c.ttl)}
+ 	return meta, err
+ }
+
+ // CachedSizeResolver wraps underlying in a SizeResolver that remembers
+ // ContentLength results per URI for the given TTL. The cache starts empty.
+ func CachedSizeResolver(underlying SizeResolver, ttl time.Duration) SizeResolver {
+ 	resolver := &cachedSizeResolver{
+ 		underlying: underlying,
+ 		ttl:        ttl,
+ 		cache:      make(map[string]sizeCacheEntry),
+ 	}
+ 	return resolver
+ }
+
+ // CachedGGUFReader wraps underlying in a GGUFMetadataReader that remembers
+ // ReadMetadata results per URI for the given TTL. The cache starts empty.
+ func CachedGGUFReader(underlying GGUFMetadataReader, ttl time.Duration) GGUFMetadataReader {
+ 	reader := &cachedGGUFReader{
+ 		underlying: underlying,
+ 		ttl:        ttl,
+ 		cache:      make(map[string]ggufCacheEntry),
+ 	}
+ 	return reader
+ }
+
+ // DefaultCachedSizeResolver returns the package-level cached SizeResolver, which
+ // wraps the default implementation with the default TTL (15 min).
+ // Every call returns the same instance, so its cache is shared by all callers and
+ // repeated lookups of the same URI within the TTL are served from the cache.
+ func DefaultCachedSizeResolver() SizeResolver {
+ return defaultCachedSizeResolver
+ }
+
+ // DefaultCachedGGUFReader returns the package-level cached GGUFMetadataReader, which
+ // wraps the default implementation with the default TTL (15 min).
+ // Every call returns the same instance, so its cache is shared by all callers and
+ // repeated metadata reads of the same URI within the TTL are served from the cache.
+ func DefaultCachedGGUFReader() GGUFMetadataReader {
+ return defaultCachedGGUFReader
+ }
+
+ // Shared default caches, built once at package initialization and handed out by
+ // DefaultCachedSizeResolver and DefaultCachedGGUFReader.
+ var (
+ defaultCachedSizeResolver = CachedSizeResolver(defaultSizeResolver{}, defaultEstimateCacheTTL)
+ defaultCachedGGUFReader = CachedGGUFReader(defaultGGUFReader{}, defaultEstimateCacheTTL)
+ )
diff --git a/pkg/vram/estimate.go b/pkg/vram/estimate.go
new file mode 100644
index 000000000..88f30c2ac
--- /dev/null
+++ b/pkg/vram/estimate.go
@@ -0,0 +1,152 @@
+package vram
+
+import (
+ "context"
+ "fmt"
+ "path"
+ "strings"
+
+ "github.com/mudler/LocalAI/pkg/downloader"
+)
+
+ // weightExts lists the file extensions (lower case, with dot) treated as model weights.
+ var weightExts = map[string]bool{
+ 	".gguf": true, ".safetensors": true, ".bin": true, ".pt": true,
+ }
+
+ // weightFileExt returns the lower-cased extension of the last path element.
+ func weightFileExt(nameOrURI string) string {
+ 	return strings.ToLower(path.Ext(path.Base(nameOrURI)))
+ }
+
+ // stripURIQuery drops everything from the first '?' or '#' onwards, so that
+ // "model.gguf?download=true" yields "model.gguf".
+ func stripURIQuery(nameOrURI string) string {
+ 	if i := strings.IndexAny(nameOrURI, "?#"); i >= 0 {
+ 		return nameOrURI[:i]
+ 	}
+ 	return nameOrURI
+ }
+
+ // isWeightFile reports whether nameOrURI ends in one of the weightExts
+ // extensions, compared case-insensitively. The string is checked as given
+ // and, failing that, with any query string or fragment removed, so download
+ // URLs carrying parameters are still recognised.
+ func isWeightFile(nameOrURI string) bool {
+ 	return weightExts[weightFileExt(nameOrURI)] || weightExts[weightFileExt(stripURIQuery(nameOrURI))]
+ }
+
+ // isGGUF reports whether nameOrURI has a .gguf extension, using the same
+ // matching rules as isWeightFile.
+ func isGGUF(nameOrURI string) bool {
+ 	return weightFileExt(nameOrURI) == ".gguf" || weightFileExt(stripURIQuery(nameOrURI)) == ".gguf"
+ }
+
+ // Estimate sums the size of the weight files in files and derives a
+ // best-effort VRAM requirement for loading them.
+ //
+ // Files that are not weight files (see isWeightFile) are ignored. A file whose
+ // Size is <= 0 is sized through sizeResolver when one is supplied; a file whose
+ // lookup fails, or whose size is negative, is left out of the totals.
+ // opts.ContextLength defaults to 8192 and opts.KVQuantBits to 16 when unset.
+ //
+ // When GGUF files are present and ggufReader returns metadata with a block
+ // count or embedding length for the first of them, VRAM is computed as
+ // weights + KV cache + overhead (see ggufVRAMBytes). Otherwise the size-only
+ // heuristic sizeOnlyVRAM is used. Both sizeResolver and ggufReader may be nil.
+ //
+ // The returned error is currently always nil.
+ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) (EstimateResult, error) {
+ 	if opts.ContextLength == 0 {
+ 		opts.ContextLength = 8192
+ 	}
+ 	if opts.KVQuantBits <= 0 {
+ 		opts.KVQuantBits = 16
+ 	}
+
+ 	var sizeBytes, ggufSize uint64
+ 	var firstGGUFURI string
+ 	for i := range files {
+ 		f := &files[i]
+ 		if !isWeightFile(f.URI) {
+ 			continue
+ 		}
+ 		sz := f.Size
+ 		if sz <= 0 && sizeResolver != nil {
+ 			resolved, err := sizeResolver.ContentLength(ctx, f.URI)
+ 			if err != nil {
+ 				// Best-effort: a file that cannot be sized is left out.
+ 				continue
+ 			}
+ 			sz = resolved
+ 		}
+ 		if sz < 0 {
+ 			// Converting a negative size to uint64 would wrap to a huge value.
+ 			continue
+ 		}
+ 		sizeBytes += uint64(sz)
+ 		if isGGUF(f.URI) {
+ 			ggufSize += uint64(sz)
+ 			if firstGGUFURI == "" {
+ 				firstGGUFURI = f.URI
+ 			}
+ 		}
+ 	}
+
+ 	var vramBytes uint64
+ 	switch {
+ 	case ggufSize > 0:
+ 		var meta *GGUFMeta
+ 		if ggufReader != nil && firstGGUFURI != "" {
+ 			// The error is deliberately dropped: without metadata we fall
+ 			// back to the size-only heuristic below.
+ 			meta, _ = ggufReader.ReadMetadata(ctx, firstGGUFURI)
+ 		}
+ 		if meta != nil && (meta.BlockCount > 0 || meta.EmbeddingLength > 0) {
+ 			vramBytes = ggufVRAMBytes(meta, ggufSize, opts)
+ 		} else {
+ 			vramBytes = sizeOnlyVRAM(ggufSize, opts.ContextLength)
+ 		}
+ 	case sizeBytes > 0:
+ 		vramBytes = sizeOnlyVRAM(sizeBytes, opts.ContextLength)
+ 	}
+
+ 	return EstimateResult{
+ 		SizeBytes:   sizeBytes,
+ 		SizeDisplay: FormatBytes(sizeBytes),
+ 		VRAMBytes:   vramBytes,
+ 		VRAMDisplay: FormatBytes(vramBytes),
+ 	}, nil
+ }
+
+ // ggufVRAMBytes estimates VRAM for a GGUF model of modelBytes from its metadata
+ // as weights + KV cache + overhead.
+ //
+ // Missing metadata falls back to 32 layers and an embedding length of 4096.
+ // The KV cache holds keys and values (factor 2) for every layer and context
+ // position; with grouped-query attention only HeadCountKV out of HeadCount
+ // heads are stored, which narrows the cached width proportionally. Its element
+ // width is opts.KVQuantBits bits (16 when unset), so sub-byte quantization is
+ // honoured. Overhead is 2% of twice the model size plus a fixed 0.15 GB.
+ // When opts.GPULayers is positive and below the layer count, only that
+ // fraction of the weights is counted.
+ func ggufVRAMBytes(meta *GGUFMeta, modelBytes uint64, opts EstimateOptions) uint64 {
+ 	nLayers := uint64(meta.BlockCount)
+ 	if nLayers == 0 {
+ 		nLayers = 32
+ 	}
+ 	dModel := uint64(meta.EmbeddingLength)
+ 	if dModel == 0 {
+ 		dModel = 4096
+ 	}
+
+ 	kvWidth := dModel
+ 	if meta.HeadCount > 0 {
+ 		kvHeads := uint64(meta.HeadCountKV)
+ 		if kvHeads == 0 {
+ 			kvHeads = uint64(meta.HeadCount)
+ 		}
+ 		kvWidth = dModel * kvHeads / uint64(meta.HeadCount)
+ 	}
+
+ 	kvBits := uint64(16)
+ 	if opts.KVQuantBits > 0 {
+ 		kvBits = uint64(opts.KVQuantBits)
+ 	}
+ 	kvBytes := 2 * kvWidth * nLayers * uint64(opts.ContextLength) * kvBits / 8
+
+ 	overhead := uint64(0.02*2*float64(modelBytes) + 0.15*1e9)
+
+ 	weights := modelBytes
+ 	if opts.GPULayers > 0 && uint64(opts.GPULayers) < nLayers {
+ 		layerRatio := float64(opts.GPULayers) / float64(nLayers)
+ 		weights = uint64(layerRatio * float64(modelBytes))
+ 	}
+ 	return weights + kvBytes + overhead
+ }
+
+func sizeOnlyVRAM(sizeOnDisk uint64, ctxLen uint32) uint64 {
+ k := uint64(1024)
+ vram := sizeOnDisk + k*uint64(ctxLen)*2
+ if vram < sizeOnDisk {
+ vram = sizeOnDisk
+ }
+ return vram
+}
+
+ // FormatBytes renders n as a human-readable size using decimal (1000-based)
+ // units: plain bytes below 1000 ("999 B"), otherwise one decimal place with a
+ // K/M/G/T/P/E prefix ("2.5 GB").
+ func FormatBytes(n uint64) string {
+ 	const (
+ 		unit     = 1000
+ 		prefixes = "KMGTPE"
+ 	)
+ 	if n < unit {
+ 		return fmt.Sprintf("%d B", n)
+ 	}
+ 	divisor := uint64(unit)
+ 	idx := 0
+ 	// Grow the divisor until the integer quotient drops below one unit.
+ 	for n/divisor >= unit {
+ 		divisor *= unit
+ 		idx++
+ 	}
+ 	scaled := float64(n) / float64(divisor)
+ 	return fmt.Sprintf("%.1f %cB", scaled, prefixes[idx])
+ }
+
+ // defaultSizeResolver is the stateless SizeResolver backed by the downloader package.
+ type defaultSizeResolver struct{}
+
+ // ContentLength delegates to downloader.URI(uri).ContentLength.
+ func (defaultSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
+ return downloader.URI(uri).ContentLength(ctx)
+ }
+
+ // DefaultSizeResolver returns the uncached default SizeResolver.
+ func DefaultSizeResolver() SizeResolver {
+ return defaultSizeResolver{}
+ }
+
+ // DefaultGGUFReader returns the uncached default GGUFMetadataReader.
+ func DefaultGGUFReader() GGUFMetadataReader {
+ return defaultGGUFReader{}
+ }
diff --git a/pkg/vram/estimate_test.go b/pkg/vram/estimate_test.go
new file mode 100644
index 000000000..2036c8dad
--- /dev/null
+++ b/pkg/vram/estimate_test.go
@@ -0,0 +1,137 @@
+package vram_test
+
+import (
+ "context"
+
+ . "github.com/mudler/LocalAI/pkg/vram"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+ // fakeSizeResolver is a test double mapping URIs to fixed sizes.
+ type fakeSizeResolver map[string]int64
+
+ // ContentLength returns the size registered for uri, or 0 for unknown URIs.
+ // It never fails.
+ func (f fakeSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
+ 	// A missing key yields the zero value, which is exactly the fallback wanted.
+ 	return f[uri], nil
+ }
+
+ // fakeGGUFReader is a test double mapping URIs to canned GGUF metadata.
+ type fakeGGUFReader map[string]*GGUFMeta
+
+ // ReadMetadata returns the metadata registered for uri (nil when unknown) and never fails.
+ func (f fakeGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
+ return f[uri], nil
+ }
+
+ // Specs for Estimate: which files are counted, how sizes are resolved, and
+ // which VRAM formula is chosen depending on the available GGUF metadata.
+ var _ = Describe("Estimate", func() {
+ 	ctx := context.Background()
+
+ 	Describe("empty or non-GGUF inputs", func() {
+ 		It("returns zero size and vram for nil files", func() {
+ 			opts := EstimateOptions{ContextLength: 8192}
+ 			res, err := Estimate(ctx, nil, opts, nil, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeBytes).To(Equal(uint64(0)))
+ 			Expect(res.VRAMBytes).To(Equal(uint64(0)))
+ 			Expect(res.SizeDisplay).To(Equal("0 B"))
+ 		})
+
+ 		// Other weight formats are counted as well (see the safetensors
+ 		// spec below); only non-weight extensions such as .txt are skipped.
+ 		It("counts weight files and ignores non-weight extensions", func() {
+ 			files := []FileInput{
+ 				{URI: "http://a/model.gguf", Size: 1_000_000_000},
+ 				{URI: "http://a/readme.txt", Size: 100},
+ 			}
+ 			opts := EstimateOptions{ContextLength: 8192}
+ 			res, err := Estimate(ctx, files, opts, nil, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeBytes).To(Equal(uint64(1_000_000_000)))
+ 		})
+
+ 		It("sums size for multiple non-GGUF weight files (e.g. safetensors)", func() {
+ 			files := []FileInput{
+ 				{URI: "http://hf.co/model/model.safetensors", Size: 2_000_000_000},
+ 				{URI: "http://hf.co/model/model2.safetensors", Size: 3_000_000_000},
+ 			}
+ 			opts := EstimateOptions{ContextLength: 8192}
+ 			res, err := Estimate(ctx, files, opts, nil, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeBytes).To(Equal(uint64(5_000_000_000)))
+ 		})
+ 	})
+
+ 	Describe("GGUF size and resolver", func() {
+ 		It("uses size resolver when file size is not set", func() {
+ 			sizes := fakeSizeResolver{"http://example.com/model.gguf": 1_500_000_000}
+ 			opts := EstimateOptions{ContextLength: 8192}
+ 			files := []FileInput{{URI: "http://example.com/model.gguf"}}
+
+ 			res, err := Estimate(ctx, files, opts, sizes, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeBytes).To(Equal(uint64(1_500_000_000)))
+ 			Expect(res.VRAMBytes).To(BeNumerically(">=", res.SizeBytes))
+ 			Expect(res.SizeDisplay).To(Equal("1.5 GB"))
+ 		})
+
+ 		It("uses size-only VRAM formula when metadata is missing and size is large", func() {
+ 			sizes := fakeSizeResolver{"http://a/model.gguf": 10_000_000_000}
+ 			opts := EstimateOptions{ContextLength: 8192}
+ 			files := []FileInput{{URI: "http://a/model.gguf"}}
+
+ 			res, err := Estimate(ctx, files, opts, sizes, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.VRAMBytes).To(BeNumerically(">", 10_000_000_000))
+ 		})
+
+ 		It("sums size for multiple GGUF shards", func() {
+ 			files := []FileInput{
+ 				{URI: "http://a/shard1.gguf", Size: 10_000_000_000},
+ 				{URI: "http://a/shard2.gguf", Size: 5_000_000_000},
+ 			}
+ 			opts := EstimateOptions{ContextLength: 8192}
+
+ 			res, err := Estimate(ctx, files, opts, nil, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000)))
+ 		})
+
+ 		It("formats size display correctly", func() {
+ 			files := []FileInput{{URI: "http://a/model.gguf", Size: 2_500_000_000}}
+ 			opts := EstimateOptions{ContextLength: 8192}
+
+ 			res, err := Estimate(ctx, files, opts, nil, nil)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeDisplay).To(Equal("2.5 GB"))
+ 		})
+ 	})
+
+ 	Describe("GGUF with metadata reader", func() {
+ 		It("uses metadata for VRAM when reader returns meta and partial offload", func() {
+ 			meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096}
+ 			reader := fakeGGUFReader{"http://a/model.gguf": meta}
+ 			opts := EstimateOptions{ContextLength: 8192, GPULayers: 20}
+ 			files := []FileInput{{URI: "http://a/model.gguf", Size: 8_000_000_000}}
+
+ 			res, err := Estimate(ctx, files, opts, nil, reader)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.VRAMBytes).To(BeNumerically(">", 0))
+ 		})
+
+ 		It("uses metadata head counts for KV and yields vram > size", func() {
+ 			files := []FileInput{{URI: "http://a/model.gguf", Size: 15_000_000_000}}
+ 			meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, HeadCount: 32, HeadCountKV: 8}
+ 			reader := fakeGGUFReader{"http://a/model.gguf": meta}
+ 			opts := EstimateOptions{ContextLength: 8192}
+
+ 			res, err := Estimate(ctx, files, opts, nil, reader)
+ 			Expect(err).ToNot(HaveOccurred())
+ 			Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000)))
+ 			Expect(res.VRAMBytes).To(BeNumerically(">", res.SizeBytes))
+ 		})
+ 	})
+ })
+
+ // Specs for FormatBytes, including the boundary between plain bytes and the
+ // first decimal unit.
+ var _ = Describe("FormatBytes", func() {
+ 	It("formats 2.5e9 as 2.5 GB", func() {
+ 		Expect(FormatBytes(2_500_000_000)).To(Equal("2.5 GB"))
+ 	})
+
+ 	It("formats values below 1000 as plain bytes", func() {
+ 		Expect(FormatBytes(0)).To(Equal("0 B"))
+ 		Expect(FormatBytes(999)).To(Equal("999 B"))
+ 	})
+
+ 	It("switches to KB at exactly 1000", func() {
+ 		Expect(FormatBytes(1000)).To(Equal("1.0 KB"))
+ 	})
+ })
diff --git a/pkg/vram/gguf_reader.go b/pkg/vram/gguf_reader.go
new file mode 100644
index 000000000..631c017f7
--- /dev/null
+++ b/pkg/vram/gguf_reader.go
@@ -0,0 +1,46 @@
+package vram
+
+import (
+ "context"
+ "strings"
+
+ "github.com/mudler/LocalAI/pkg/downloader"
+ gguf "github.com/gpustack/gguf-parser-go"
+)
+
+ // defaultGGUFReader is the stateless GGUFMetadataReader backed by gguf-parser-go.
+ type defaultGGUFReader struct{}
+
+ // ReadMetadata parses GGUF metadata for uri. URIs starting with
+ // downloader.LocalPrefix are parsed from the resolved local path, HTTP(S) URLs
+ // are parsed remotely using ctx, and any other scheme yields (nil, nil).
+ func (defaultGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
+ 	source := downloader.URI(uri)
+ 	resolved := source.ResolveURL()
+
+ 	var (
+ 		parsed *gguf.GGUFFile
+ 		err    error
+ 	)
+ 	switch {
+ 	case strings.HasPrefix(uri, downloader.LocalPrefix):
+ 		parsed, err = gguf.ParseGGUFFile(resolved)
+ 	case source.LooksLikeHTTPURL():
+ 		parsed, err = gguf.ParseGGUFFileRemote(ctx, resolved)
+ 	default:
+ 		// Unsupported scheme: report "no metadata" rather than an error.
+ 		return nil, nil
+ 	}
+ 	if err != nil {
+ 		return nil, err
+ 	}
+ 	return ggufFileToMeta(parsed), nil
+ }
+
+ // ggufFileToMeta copies the architecture fields used for estimation out of a
+ // parsed GGUF file, narrowing each to uint32. A zero KV head count is replaced
+ // by the attention head count.
+ func ggufFileToMeta(f *gguf.GGUFFile) *GGUFMeta {
+ 	arch := f.Architecture()
+ 	heads := uint32(arch.AttentionHeadCount)
+ 	kvHeads := uint32(arch.AttentionHeadCountKV)
+ 	if kvHeads == 0 {
+ 		kvHeads = heads
+ 	}
+ 	return &GGUFMeta{
+ 		BlockCount:      uint32(arch.BlockCount),
+ 		EmbeddingLength: uint32(arch.EmbeddingLength),
+ 		HeadCount:       heads,
+ 		HeadCountKV:     kvHeads,
+ 	}
+ }
diff --git a/pkg/vram/types.go b/pkg/vram/types.go
new file mode 100644
index 000000000..cda76aff6
--- /dev/null
+++ b/pkg/vram/types.go
@@ -0,0 +1,42 @@
+package vram
+
+import "context"
+
+ // FileInput represents a single model file for estimation (URI and optional pre-known size).
+ type FileInput struct {
+ // URI locates the file; Estimate also uses its extension to decide whether it is a weight file.
+ URI string
+ // Size is the known size in bytes. When it is <= 0 and a SizeResolver is supplied,
+ // Estimate asks the resolver for the size instead.
+ Size int64
+ }
+
+ // SizeResolver returns the content length in bytes for a given URI.
+ // Estimate skips a file when its lookup returns an error.
+ type SizeResolver interface {
+ ContentLength(ctx context.Context, uri string) (int64, error)
+ }
+
+ // GGUFMeta holds parsed GGUF metadata used for VRAM estimation.
+ // A zero field means "unknown"; Estimate substitutes defaults for zero values.
+ type GGUFMeta struct {
+ // BlockCount is the architecture's block count, used as the number of layers.
+ BlockCount uint32
+ // EmbeddingLength is the architecture's embedding length, used as the model width.
+ EmbeddingLength uint32
+ // HeadCount is the attention head count.
+ HeadCount uint32
+ // HeadCountKV is the key/value attention head count.
+ HeadCountKV uint32
+ }
+
+ // GGUFMetadataReader reads GGUF metadata from a URI (e.g. via HTTP Range).
+ // Estimate treats a nil result, like an error, as "no metadata available".
+ type GGUFMetadataReader interface {
+ ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error)
+ }
+
+ // EstimateOptions configures VRAM/size estimation.
+ type EstimateOptions struct {
+ // ContextLength is the context size in tokens; 0 selects the default of 8192.
+ ContextLength uint32
+ // GPULayers is the number of layers offloaded to the GPU; values <= 0 mean all layers.
+ // It only affects estimates that are based on GGUF metadata.
+ GPULayers int
+ // KVQuantBits is the KV cache element width in bits; 0 selects the default of 16.
+ KVQuantBits int
+ }
+
+ // EstimateResult holds estimated download size and VRAM with display strings.
+ type EstimateResult struct {
+ // SizeBytes is the summed size of all counted weight files.
+ SizeBytes uint64
+ // SizeDisplay is SizeBytes rendered by FormatBytes.
+ SizeDisplay string
+ // VRAMBytes is the estimated VRAM requirement; 0 when nothing could be sized.
+ VRAMBytes uint64
+ // VRAMDisplay is VRAMBytes rendered by FormatBytes.
+ VRAMDisplay string
+ }
diff --git a/pkg/vram/vram_suite_test.go b/pkg/vram/vram_suite_test.go
new file mode 100644
index 000000000..0c8116922
--- /dev/null
+++ b/pkg/vram/vram_suite_test.go
@@ -0,0 +1,13 @@
+package vram_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+ // TestVram is the go test entry point that runs the package's Ginkgo specs.
+ func TestVram(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Vram test suite")
+ }