mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
feat(ui): add model size estimation (#8684)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
b260378694
commit
983db7bedc
@@ -1,6 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -37,6 +40,31 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
return fmt.Errorf("failed to discover model config: %w", err)
|
||||
}
|
||||
|
||||
resp := schema.GalleryResponse{
|
||||
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), ""),
|
||||
}
|
||||
|
||||
if len(modelConfig.Files) > 0 {
|
||||
files := make([]vram.FileInput, 0, len(modelConfig.Files))
|
||||
for _, f := range modelConfig.Files {
|
||||
files = append(files, vram.FileInput{URI: f.URI, Size: 0})
|
||||
}
|
||||
estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
opts := vram.EstimateOptions{ContextLength: 8192}
|
||||
result, err := vram.Estimate(estCtx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader())
|
||||
if err == nil {
|
||||
if result.SizeBytes > 0 {
|
||||
resp.EstimatedSizeBytes = result.SizeBytes
|
||||
resp.EstimatedSizeDisplay = result.SizeDisplay
|
||||
}
|
||||
if result.VRAMBytes > 0 {
|
||||
resp.EstimatedVRAMBytes = result.VRAMBytes
|
||||
resp.EstimatedVRAMDisplay = result.VRAMDisplay
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -63,10 +91,9 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.GalleryResponse{
|
||||
ID: uuid.String(),
|
||||
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
|
||||
})
|
||||
resp.ID = uuid.String()
|
||||
resp.StatusURL = fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String())
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -22,6 +25,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -263,6 +267,22 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
modelsJSON := make([]map[string]interface{}, 0, len(models))
|
||||
seenIDs := make(map[string]bool)
|
||||
|
||||
weightExts := map[string]bool{".gguf": true, ".safetensors": true, ".bin": true, ".pt": true}
|
||||
hasWeightFiles := func(files []gallery.File) bool {
|
||||
for _, f := range files {
|
||||
ext := strings.ToLower(path.Ext(path.Base(f.URI)))
|
||||
if weightExts[ext] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const estimateTimeout = 3 * time.Second
|
||||
const estimateConcurrency = 3
|
||||
sem := make(chan struct{}, estimateConcurrency)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, m := range models {
|
||||
modelID := m.ID()
|
||||
|
||||
@@ -286,7 +306,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
|
||||
_, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
|
||||
|
||||
modelsJSON = append(modelsJSON, map[string]interface{}{
|
||||
obj := map[string]interface{}{
|
||||
"id": modelID,
|
||||
"name": m.Name,
|
||||
"description": m.Description,
|
||||
@@ -301,9 +321,48 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
"isDeletion": isDeletionOp,
|
||||
"trustRemoteCode": trustRemoteCodeExists,
|
||||
"additionalFiles": m.AdditionalFiles,
|
||||
})
|
||||
}
|
||||
|
||||
if hasWeightFiles(m.AdditionalFiles) {
|
||||
files := make([]gallery.File, len(m.AdditionalFiles))
|
||||
copy(files, m.AdditionalFiles)
|
||||
wg.Add(1)
|
||||
go func(files []gallery.File, out map[string]interface{}) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
inputs := make([]vram.FileInput, 0, len(files))
|
||||
for _, f := range files {
|
||||
ext := strings.ToLower(path.Ext(path.Base(f.URI)))
|
||||
if weightExts[ext] {
|
||||
inputs = append(inputs, vram.FileInput{URI: f.URI, Size: 0})
|
||||
}
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), estimateTimeout)
|
||||
defer cancel()
|
||||
opts := vram.EstimateOptions{ContextLength: 8192}
|
||||
result, err := vram.Estimate(ctx, inputs, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader())
|
||||
if err == nil {
|
||||
if result.SizeBytes > 0 {
|
||||
out["estimated_size_bytes"] = result.SizeBytes
|
||||
out["estimated_size_display"] = result.SizeDisplay
|
||||
}
|
||||
if result.VRAMBytes > 0 {
|
||||
out["estimated_vram_bytes"] = result.VRAMBytes
|
||||
out["estimated_vram_display"] = result.VRAMDisplay
|
||||
}
|
||||
}
|
||||
}(files, obj)
|
||||
}
|
||||
|
||||
modelsJSON = append(modelsJSON, obj)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
prevPage := pageNum - 1
|
||||
nextPage := pageNum + 1
|
||||
if prevPage < 1 {
|
||||
@@ -318,10 +377,6 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
installedModelsCount := len(modelConfigs) + len(modelsWithoutConfig)
|
||||
|
||||
// Calculate storage size and RAM info
|
||||
modelsPath := appConfig.SystemState.Model.ModelsPath
|
||||
storageSize, _ := getDirectorySize(modelsPath)
|
||||
|
||||
ramInfo, _ := xsysinfo.GetSystemRAMInfo()
|
||||
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
@@ -332,7 +387,6 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
"taskTypes": taskTypes,
|
||||
"availableModels": totalModels,
|
||||
"installedModels": installedModelsCount,
|
||||
"storageSize": storageSize,
|
||||
"ramTotal": ramInfo.Total,
|
||||
"ramUsed": ramInfo.Used,
|
||||
"ramUsagePercent": ramInfo.UsagePercent,
|
||||
@@ -967,12 +1021,15 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
watchdogInterval = appConfig.WatchDogInterval.String()
|
||||
}
|
||||
|
||||
storageSize, _ := getDirectorySize(appConfig.SystemState.Model.ModelsPath)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"type": resourceInfo.Type, // "gpu" or "ram"
|
||||
"available": resourceInfo.Available,
|
||||
"gpus": resourceInfo.GPUs,
|
||||
"ram": resourceInfo.RAM,
|
||||
"aggregate": resourceInfo.Aggregate,
|
||||
"storage_size": storageSize,
|
||||
"reclaimer_enabled": appConfig.MemoryReclaimerEnabled,
|
||||
"reclaimer_threshold": appConfig.MemoryReclaimerThreshold,
|
||||
"watchdog_interval": watchdogInterval,
|
||||
|
||||
@@ -141,6 +141,15 @@
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<!-- Models storage (disk usage) -->
|
||||
<template x-if="resourceData.storage_size != null">
|
||||
<div class="mt-3 pt-3 border-t border-[var(--color-primary-border)]/20">
|
||||
<div class="flex justify-between text-xs">
|
||||
<span class="text-[var(--color-text-secondary)]">Models storage</span>
|
||||
<span class="font-mono text-[var(--color-text-primary)]" x-text="formatBytes(resourceData.storage_size)"></span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
@@ -59,6 +59,26 @@
|
||||
<!-- Alert Messages -->
|
||||
<div id="alertContainer" class="mb-6"></div>
|
||||
|
||||
<!-- Persistent estimate (stays visible so user can see size/VRAM even if alert is replaced) -->
|
||||
<div x-show="!isAdvancedMode && !isEditMode && lastEstimate && ((lastEstimate.sizeDisplay && lastEstimate.sizeDisplay !== '0 B') || (lastEstimate.vramDisplay && lastEstimate.vramDisplay !== '0 B'))"
|
||||
x-transition
|
||||
class="mb-6 p-4 rounded-xl border border-[var(--color-primary)]/30 bg-[var(--color-primary-light)]/30">
|
||||
<h3 class="text-sm font-semibold text-[var(--color-text-primary)] mb-2 flex items-center gap-2">
|
||||
<i class="fas fa-memory text-[var(--color-primary)]"></i>
|
||||
Estimated requirements
|
||||
</h3>
|
||||
<div class="flex flex-wrap gap-4 text-sm text-[var(--color-text-secondary)]">
|
||||
<span x-show="lastEstimate && lastEstimate.sizeDisplay && lastEstimate.sizeDisplay !== '0 B'">
|
||||
<i class="fas fa-download mr-1.5 text-[var(--color-primary)]"></i>
|
||||
Download size: <span class="font-medium text-[var(--color-text-primary)]" x-text="lastEstimate?.sizeDisplay"></span>
|
||||
</span>
|
||||
<span x-show="lastEstimate && lastEstimate.vramDisplay && lastEstimate.vramDisplay !== '0 B'">
|
||||
<i class="fas fa-microchip mr-1.5 text-[var(--color-primary)]"></i>
|
||||
VRAM: <span class="font-medium text-[var(--color-text-primary)]" x-text="lastEstimate?.vramDisplay"></span>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Simple Import Mode -->
|
||||
<div x-show="!isAdvancedMode && !isEditMode"
|
||||
x-transition:enter="transition ease-out duration-200"
|
||||
@@ -731,6 +751,7 @@ function importModel() {
|
||||
jobPollInterval: null,
|
||||
yamlEditor: null,
|
||||
modelEditor: null,
|
||||
lastEstimate: null,
|
||||
|
||||
init() {
|
||||
// If in edit mode, always show advanced mode
|
||||
@@ -854,15 +875,36 @@ function importModel() {
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
|
||||
|
||||
const hasSize = result.estimated_size_display && result.estimated_size_display !== '0 B';
|
||||
const hasVram = result.estimated_vram_display && result.estimated_vram_display !== '0 B';
|
||||
if (hasSize || hasVram) {
|
||||
this.lastEstimate = {
|
||||
sizeDisplay: result.estimated_size_display || '',
|
||||
vramDisplay: result.estimated_vram_display || '',
|
||||
sizeBytes: result.estimated_size_bytes || 0,
|
||||
vramBytes: result.estimated_vram_bytes || 0
|
||||
};
|
||||
} else {
|
||||
this.lastEstimate = null;
|
||||
}
|
||||
|
||||
let successMsg = 'Import started! Tracking progress...';
|
||||
if (hasSize || hasVram) {
|
||||
const parts = [];
|
||||
if (hasSize) parts.push('Size: ' + result.estimated_size_display);
|
||||
if (hasVram) parts.push('VRAM: ' + result.estimated_vram_display);
|
||||
successMsg += ' (' + parts.join(' · ') + ')';
|
||||
}
|
||||
|
||||
if (result.uuid) {
|
||||
this.currentJobId = result.uuid;
|
||||
this.showAlert('success', 'Import started! Tracking progress...');
|
||||
this.showAlert('success', successMsg);
|
||||
this.startJobPolling();
|
||||
} else if (result.ID) {
|
||||
// Fallback for different response format
|
||||
this.currentJobId = result.ID;
|
||||
this.showAlert('success', 'Import started! Tracking progress...');
|
||||
this.showAlert('success', successMsg);
|
||||
this.startJobPolling();
|
||||
} else {
|
||||
throw new Error('No job ID returned from server');
|
||||
|
||||
@@ -61,15 +61,6 @@
|
||||
<span class="font-semibold text-purple-300" x-text="repositories.length"></span>
|
||||
<span class="text-[var(--color-text-secondary)] ml-1">repositories</span>
|
||||
</div>
|
||||
<div class="flex items-center bg-[var(--color-bg-primary)] rounded-lg px-4 py-2">
|
||||
<div class="w-2 h-2 bg-blue-500 rounded-full mr-2"></div>
|
||||
<span class="font-semibold text-blue-300" x-text="formatBytes(storageSize)"></span>
|
||||
<span class="text-[var(--color-text-secondary)] ml-1">storage</span>
|
||||
</div>
|
||||
<div x-show="storageSize > ramTotal" class="flex items-center bg-red-500/20 rounded-lg px-4 py-2 border border-red-500/30">
|
||||
<i class="fas fa-exclamation-triangle text-red-400 mr-2"></i>
|
||||
<span class="text-red-300 text-sm">Storage exceeds RAM!</span>
|
||||
</div>
|
||||
<a href="/import-model" class="inline-flex items-center gap-1.5 text-xs text-[var(--color-text-secondary)] hover:text-[var(--color-primary)] bg-transparent hover:bg-[var(--color-primary)]/10 border border-[var(--color-border-subtle)] hover:border-[var(--color-primary)]/30 rounded-md py-1.5 px-2.5 transition-colors">
|
||||
<i class="fas fa-upload"></i>
|
||||
<span>Import Model</span>
|
||||
@@ -186,7 +177,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Results Section -->
|
||||
<div id="search-results" class="transition-all duration-300">
|
||||
<div id="search-results" class="transition-all duration-300 relative">
|
||||
<div x-show="loading && models.length === 0" class="text-center py-12">
|
||||
<svg class="animate-spin h-12 w-12 text-[var(--color-primary)] mx-auto mb-4" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||
@@ -200,6 +191,21 @@
|
||||
<p class="text-[var(--color-text-secondary)]">No models found matching your criteria</p>
|
||||
</div>
|
||||
|
||||
<!-- Loading overlay when switching pages (we have models but loading) -->
|
||||
<div x-show="loading && models.length > 0"
|
||||
x-transition:enter="transition ease-out duration-150"
|
||||
x-transition:enter-start="opacity-0"
|
||||
x-transition:enter-end="opacity-100"
|
||||
class="absolute inset-0 z-10 flex items-center justify-center rounded-2xl bg-[var(--color-bg-secondary)]/80 backdrop-blur-sm">
|
||||
<div class="flex flex-col items-center gap-3">
|
||||
<svg class="animate-spin h-12 w-12 text-[var(--color-primary)]" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||
</svg>
|
||||
<p class="text-sm text-[var(--color-text-secondary)]">Loading page...</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Table View -->
|
||||
<div x-show="models.length > 0" class="bg-[var(--color-bg-secondary)] rounded-2xl border border-[var(--color-border-subtle)] overflow-hidden shadow-xl backdrop-blur-sm">
|
||||
<div class="overflow-x-auto">
|
||||
@@ -218,26 +224,7 @@
|
||||
</div>
|
||||
</th>
|
||||
<th class="px-6 py-4 text-left text-xs font-semibold text-[var(--color-text-primary)] uppercase tracking-wider">Description</th>
|
||||
<th @click="setSort('repository')"
|
||||
:class="sortBy === 'repository' ? 'bg-[var(--color-primary-light)]' : ''"
|
||||
class="px-6 py-4 text-left text-xs font-semibold text-[var(--color-text-primary)] uppercase tracking-wider cursor-pointer hover:bg-[var(--color-bg-primary)] transition-colors">
|
||||
<div class="flex items-center gap-2">
|
||||
<span>Repository</span>
|
||||
<i :class="sortBy === 'repository' ? (sortOrder === 'asc' ? 'fas fa-sort-up' : 'fas fa-sort-down') : 'fas fa-sort'"
|
||||
:class="sortBy === 'repository' ? 'text-[var(--color-primary)]' : 'text-[var(--color-text-secondary)]'"
|
||||
class="text-xs"></i>
|
||||
</div>
|
||||
</th>
|
||||
<th @click="setSort('license')"
|
||||
:class="sortBy === 'license' ? 'bg-[var(--color-primary-light)]' : ''"
|
||||
class="px-6 py-4 text-left text-xs font-semibold text-[var(--color-text-primary)] uppercase tracking-wider cursor-pointer hover:bg-[var(--color-bg-primary)] transition-colors">
|
||||
<div class="flex items-center gap-2">
|
||||
<span>License</span>
|
||||
<i :class="sortBy === 'license' ? (sortOrder === 'asc' ? 'fas fa-sort-up' : 'fas fa-sort-down') : 'fas fa-sort'"
|
||||
:class="sortBy === 'license' ? 'text-[var(--color-primary)]' : 'text-[var(--color-text-secondary)]'"
|
||||
class="text-xs"></i>
|
||||
</div>
|
||||
</th>
|
||||
<th class="px-6 py-4 text-left text-xs font-semibold text-[var(--color-text-primary)] uppercase tracking-wider">Size / VRAM</th>
|
||||
<th @click="setSort('status')"
|
||||
:class="sortBy === 'status' ? 'bg-[var(--color-primary-light)]' : ''"
|
||||
class="px-6 py-4 text-left text-xs font-semibold text-[var(--color-text-primary)] uppercase tracking-wider cursor-pointer hover:bg-[var(--color-bg-primary)] transition-colors">
|
||||
@@ -284,21 +271,26 @@
|
||||
<div class="text-sm text-[var(--color-text-secondary)] max-w-xs truncate" x-text="model.description" :title="model.description"></div>
|
||||
</td>
|
||||
|
||||
<!-- Repository -->
|
||||
<!-- Size / VRAM -->
|
||||
<td class="px-6 py-4">
|
||||
<span class="inline-flex items-center text-xs px-2 py-1 rounded bg-[var(--color-primary-light)] text-[var(--color-text-primary)] border border-[var(--color-primary-border)]">
|
||||
<i class="fa-brands fa-git-alt mr-1"></i>
|
||||
<span x-text="model.gallery"></span>
|
||||
</span>
|
||||
</td>
|
||||
|
||||
<!-- License -->
|
||||
<td class="px-6 py-4">
|
||||
<span x-show="model.license" class="inline-flex items-center text-xs px-2 py-1 rounded bg-[var(--color-accent-light)] text-[var(--color-text-primary)] border border-[var(--color-accent)]/30">
|
||||
<i class="fas fa-book mr-1"></i>
|
||||
<span x-text="model.license"></span>
|
||||
</span>
|
||||
<span x-show="!model.license" class="text-xs text-[var(--color-text-secondary)]">-</span>
|
||||
<div class="flex flex-col gap-0.5">
|
||||
<template x-if="(model.estimated_size_display && model.estimated_size_display !== '0 B') || (model.estimated_vram_display && model.estimated_vram_display !== '0 B')">
|
||||
<div class="text-xs text-[var(--color-text-secondary)]">
|
||||
<span x-show="model.estimated_size_display && model.estimated_size_display !== '0 B'" x-text="'Size: ' + model.estimated_size_display"></span>
|
||||
<span x-show="(model.estimated_size_display && model.estimated_size_display !== '0 B') && (model.estimated_vram_display && model.estimated_vram_display !== '0 B')"> · </span>
|
||||
<span x-show="model.estimated_vram_display && model.estimated_vram_display !== '0 B'" x-text="'VRAM: ' + model.estimated_vram_display"></span>
|
||||
</div>
|
||||
</template>
|
||||
<template x-if="model.estimated_vram_bytes && totalMemory > 0">
|
||||
<span :title="(model.estimated_vram_bytes <= totalMemory * 0.95 ? 'Fits your GPU' : 'May not fit your GPU')"
|
||||
class="inline-flex items-center text-xs">
|
||||
<i class="fas fa-microchip mr-1"
|
||||
:class="model.estimated_vram_bytes <= totalMemory * 0.95 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'"></i>
|
||||
<span x-text="model.estimated_vram_bytes <= totalMemory * 0.95 ? 'Fits' : 'May not fit'"></span>
|
||||
</span>
|
||||
</template>
|
||||
<span x-show="(!model.estimated_size_display || model.estimated_size_display === '0 B') && (!model.estimated_vram_display || model.estimated_vram_display === '0 B')" class="text-xs text-[var(--color-text-muted)]">-</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Status -->
|
||||
@@ -414,6 +406,36 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="text-base leading-relaxed text-[var(--color-text-secondary)] break-words max-w-full markdown-content" x-html="renderMarkdown(selectedModel?.description)"></div>
|
||||
<template x-if="(selectedModel?.estimated_size_display && selectedModel.estimated_size_display !== '0 B') || (selectedModel?.estimated_vram_display && selectedModel.estimated_vram_display !== '0 B')">
|
||||
<div class="space-y-1">
|
||||
<p x-show="selectedModel?.estimated_size_display && selectedModel.estimated_size_display !== '0 B'" class="text-sm text-[var(--color-text-secondary)]">
|
||||
<i class="fas fa-download mr-2 text-[var(--color-primary)]"></i>
|
||||
Estimated download size: <span x-text="selectedModel?.estimated_size_display" class="font-medium text-[var(--color-text-primary)]"></span>
|
||||
</p>
|
||||
<p x-show="selectedModel?.estimated_vram_display && selectedModel.estimated_vram_display !== '0 B'" class="text-sm text-[var(--color-text-secondary)]">
|
||||
<i class="fas fa-memory mr-2 text-[var(--color-primary)]"></i>
|
||||
Estimated VRAM: <span x-text="selectedModel?.estimated_vram_display" class="font-medium text-[var(--color-text-primary)]"></span>
|
||||
</p>
|
||||
<p x-show="selectedModel?.estimated_vram_bytes && totalMemory > 0" class="text-sm">
|
||||
<i class="fas fa-microchip mr-2"
|
||||
:class="selectedModel?.estimated_vram_bytes <= totalMemory * 0.95 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'"></i>
|
||||
<span x-text="selectedModel?.estimated_vram_bytes <= totalMemory * 0.95 ? 'Fits your GPU' : 'May not fit your GPU'"
|
||||
:class="selectedModel?.estimated_vram_bytes <= totalMemory * 0.95 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'"></span>
|
||||
</p>
|
||||
</div>
|
||||
</template>
|
||||
<template x-if="selectedModel?.gallery || selectedModel?.license">
|
||||
<div class="space-y-1">
|
||||
<p x-show="selectedModel?.gallery" class="text-sm text-[var(--color-text-secondary)]">
|
||||
<i class="fa-brands fa-git-alt mr-2 text-[var(--color-primary)]"></i>
|
||||
Repository: <span x-text="selectedModel?.gallery" class="font-medium text-[var(--color-text-primary)]"></span>
|
||||
</p>
|
||||
<p x-show="selectedModel?.license" class="text-sm text-[var(--color-text-secondary)]">
|
||||
<i class="fas fa-book mr-2 text-[var(--color-primary)]"></i>
|
||||
License: <span x-text="selectedModel?.license" class="font-medium text-[var(--color-text-primary)]"></span>
|
||||
</p>
|
||||
</div>
|
||||
</template>
|
||||
<hr>
|
||||
<template x-if="selectedModel?.urls && selectedModel.urls.length > 0">
|
||||
<div>
|
||||
@@ -614,10 +636,10 @@ function modelsGallery() {
|
||||
totalPages: 1,
|
||||
availableModels: 0,
|
||||
installedModels: 0,
|
||||
storageSize: 0,
|
||||
ramTotal: 0,
|
||||
ramUsed: 0,
|
||||
ramUsagePercent: 0,
|
||||
totalMemory: 0,
|
||||
selectedModel: null,
|
||||
jobProgress: {},
|
||||
notifications: [],
|
||||
@@ -626,10 +648,21 @@ function modelsGallery() {
|
||||
|
||||
init() {
|
||||
this.fetchModels();
|
||||
this.fetchResources();
|
||||
// Poll for job progress every 600ms
|
||||
setInterval(() => this.pollJobs(), 600);
|
||||
},
|
||||
|
||||
async fetchResources() {
|
||||
try {
|
||||
const response = await fetch('/api/resources');
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
this.totalMemory = data.aggregate?.total_memory || 0;
|
||||
}
|
||||
} catch (e) {}
|
||||
},
|
||||
|
||||
addNotification(message, type = 'error') {
|
||||
const id = Date.now();
|
||||
this.notifications.push({ id, message, type });
|
||||
@@ -663,7 +696,6 @@ function modelsGallery() {
|
||||
this.totalPages = data.totalPages || 1;
|
||||
this.availableModels = data.availableModels || 0;
|
||||
this.installedModels = data.installedModels || 0;
|
||||
this.storageSize = data.storageSize || 0;
|
||||
this.ramTotal = data.ramTotal || 0;
|
||||
this.ramUsed = data.ramUsed || 0;
|
||||
this.ramUsagePercent = data.ramUsagePercent || 0;
|
||||
|
||||
@@ -24,6 +24,11 @@ type BackendMonitorResponse struct {
|
||||
type GalleryResponse struct {
|
||||
ID string `json:"uuid"`
|
||||
StatusURL string `json:"status"`
|
||||
|
||||
EstimatedVRAMBytes uint64 `json:"estimated_vram_bytes,omitempty"`
|
||||
EstimatedVRAMDisplay string `json:"estimated_vram_display,omitempty"`
|
||||
EstimatedSizeBytes uint64 `json:"estimated_size_bytes,omitempty"`
|
||||
EstimatedSizeDisplay string `json:"estimated_size_display,omitempty"`
|
||||
}
|
||||
|
||||
type VideoRequest struct {
|
||||
|
||||
@@ -31,6 +31,15 @@ GPT and text generation models might have a license which is not permissive for
|
||||
|
||||
Navigate the WebUI interface in the "Models" section from the navbar at the top. Here you can find a list of models that can be installed, and you can install them by clicking the "Install" button.
|
||||
|
||||
## VRAM and download size estimates
|
||||
|
||||
When browsing the gallery or importing a model by URI, LocalAI can show **estimated download size** and **estimated VRAM** for models.
|
||||
|
||||
- **Where they appear**: In the model gallery table (Size / VRAM column), in the model detail modal, and after starting an import from URI (in the success message).
|
||||
- **How they are computed**: GGUF models use file size (HTTP HEAD or local stat) and optional GGUF metadata (HTTP Range) for KV cache and overhead; other formats use Hugging Face file sizes and optional config when available. If metadata is unavailable, a size-only heuristic is used.
|
||||
- **Hardware fit indicator**: When your system reports GPU or RAM capacity, the gallery shows whether the estimated VRAM fits (green) or may not fit (red) using a 95% headroom rule.
|
||||
- Estimates are best-effort and may be missing if the server does not support HEAD/Range or the request times out.
|
||||
|
||||
## Add other galleries
|
||||
|
||||
You can add other galleries by:
|
||||
|
||||
@@ -275,6 +275,68 @@ func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
|
||||
return resp.Header.Get("Accept-Ranges") == "bytes", nil
|
||||
}
|
||||
|
||||
// ContentLength returns the size in bytes of the resource at the URI.
|
||||
// For file:// it uses os.Stat on the resolved path; for HTTP/HTTPS it uses HEAD
|
||||
// and optionally a Range request if Content-Length is missing.
|
||||
func (u URI) ContentLength(ctx context.Context) (int64, error) {
|
||||
urlStr := u.ResolveURL()
|
||||
if strings.HasPrefix(string(u), LocalPrefix) {
|
||||
info, err := os.Stat(urlStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return info.Size(), nil
|
||||
}
|
||||
if !u.LooksLikeHTTPURL() {
|
||||
return 0, fmt.Errorf("unsupported URI scheme for ContentLength: %s", string(u))
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "HEAD", urlStr, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return 0, fmt.Errorf("HEAD %s: status %d", urlStr, resp.StatusCode)
|
||||
}
|
||||
if resp.ContentLength >= 0 {
|
||||
return resp.ContentLength, nil
|
||||
}
|
||||
if resp.Header.Get("Accept-Ranges") != "bytes" {
|
||||
return 0, fmt.Errorf("HEAD %s: no Content-Length and server does not support Range", urlStr)
|
||||
}
|
||||
req2, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
req2.Header.Set("Range", "bytes=0-0")
|
||||
resp2, err := http.DefaultClient.Do(req2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
if resp2.StatusCode != http.StatusPartialContent && resp2.StatusCode != http.StatusOK {
|
||||
return 0, fmt.Errorf("Range request %s: status %d", urlStr, resp2.StatusCode)
|
||||
}
|
||||
cr := resp2.Header.Get("Content-Range")
|
||||
// Content-Range: bytes 0-0/12345
|
||||
if cr == "" {
|
||||
return 0, fmt.Errorf("Range request %s: no Content-Range header", urlStr)
|
||||
}
|
||||
parts := strings.Split(cr, "/")
|
||||
if len(parts) != 2 {
|
||||
return 0, fmt.Errorf("invalid Content-Range: %s", cr)
|
||||
}
|
||||
size, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
|
||||
if err != nil || size < 0 {
|
||||
return 0, fmt.Errorf("invalid Content-Range total length: %s", parts[1])
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||
return uri.DownloadFileWithContext(context.Background(), filePath, sha, fileN, total, downloadStatus)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package downloader_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
@@ -48,6 +51,86 @@ var _ = Describe("Gallery API tests", func() {
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("ContentLength", func() {
|
||||
Context("local file", func() {
|
||||
It("returns file size for existing file", func() {
|
||||
dir, err := os.MkdirTemp("", "contentlength-*")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(dir)
|
||||
fpath := filepath.Join(dir, "model.gguf")
|
||||
err = os.WriteFile(fpath, make([]byte, 1234), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
uri := URI("file://" + fpath)
|
||||
ctx := context.Background()
|
||||
size, err := uri.ContentLength(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(size).To(Equal(int64(1234)))
|
||||
})
|
||||
It("returns error for missing file", func() {
|
||||
uri := URI("file:///nonexistent/path/model.gguf")
|
||||
ctx := context.Background()
|
||||
_, err := uri.ContentLength(ctx)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
Context("HTTP", func() {
|
||||
It("returns Content-Length when present", func() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Method).To(Equal("HEAD"))
|
||||
w.Header().Set("Content-Length", "1000")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
uri := URI(server.URL)
|
||||
ctx := context.Background()
|
||||
size, err := uri.ContentLength(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(size).To(Equal(int64(1000)))
|
||||
})
|
||||
It("returns error on 404", func() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
uri := URI(server.URL)
|
||||
ctx := context.Background()
|
||||
_, err := uri.ContentLength(ctx)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
It("uses Range when Content-Length missing and Accept-Ranges bytes", func() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
Expect(r.Header.Get("Range")).To(Equal("bytes=0-0"))
|
||||
w.Header().Set("Content-Range", "bytes 0-0/5000")
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
uri := URI(server.URL)
|
||||
ctx := context.Background()
|
||||
size, err := uri.ContentLength(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(size).To(Equal(int64(5000)))
|
||||
})
|
||||
It("respects context cancellation", func() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "1000")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
uri := URI(server.URL)
|
||||
_, err := uri.ContentLength(ctx)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, context.Canceled)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
type RangeHeaderError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
96
pkg/vram/cache.go
Normal file
96
pkg/vram/cache.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package vram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// defaultEstimateCacheTTL bounds how long cached size and GGUF-metadata
// lookups are reused before being refreshed.
const defaultEstimateCacheTTL = 15 * time.Minute
|
||||
|
||||
type sizeCacheEntry struct {
|
||||
size int64
|
||||
err error
|
||||
until time.Time
|
||||
}
|
||||
|
||||
type cachedSizeResolver struct {
|
||||
underlying SizeResolver
|
||||
ttl time.Duration
|
||||
mu sync.Mutex
|
||||
cache map[string]sizeCacheEntry
|
||||
}
|
||||
|
||||
func (c *cachedSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
|
||||
c.mu.Lock()
|
||||
e, ok := c.cache[uri]
|
||||
c.mu.Unlock()
|
||||
if ok && time.Now().Before(e.until) {
|
||||
return e.size, e.err
|
||||
}
|
||||
size, err := c.underlying.ContentLength(ctx, uri)
|
||||
c.mu.Lock()
|
||||
if c.cache == nil {
|
||||
c.cache = make(map[string]sizeCacheEntry)
|
||||
}
|
||||
c.cache[uri] = sizeCacheEntry{size: size, err: err, until: time.Now().Add(c.ttl)}
|
||||
c.mu.Unlock()
|
||||
return size, err
|
||||
}
|
||||
|
||||
type ggufCacheEntry struct {
|
||||
meta *GGUFMeta
|
||||
err error
|
||||
until time.Time
|
||||
}
|
||||
|
||||
type cachedGGUFReader struct {
|
||||
underlying GGUFMetadataReader
|
||||
ttl time.Duration
|
||||
mu sync.Mutex
|
||||
cache map[string]ggufCacheEntry
|
||||
}
|
||||
|
||||
func (c *cachedGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
|
||||
c.mu.Lock()
|
||||
e, ok := c.cache[uri]
|
||||
c.mu.Unlock()
|
||||
if ok && time.Now().Before(e.until) {
|
||||
return e.meta, e.err
|
||||
}
|
||||
meta, err := c.underlying.ReadMetadata(ctx, uri)
|
||||
c.mu.Lock()
|
||||
if c.cache == nil {
|
||||
c.cache = make(map[string]ggufCacheEntry)
|
||||
}
|
||||
c.cache[uri] = ggufCacheEntry{meta: meta, err: err, until: time.Now().Add(c.ttl)}
|
||||
c.mu.Unlock()
|
||||
return meta, err
|
||||
}
|
||||
|
||||
// CachedSizeResolver returns a SizeResolver that caches ContentLength results by URI for the given TTL.
|
||||
func CachedSizeResolver(underlying SizeResolver, ttl time.Duration) SizeResolver {
|
||||
return &cachedSizeResolver{underlying: underlying, ttl: ttl, cache: make(map[string]sizeCacheEntry)}
|
||||
}
|
||||
|
||||
// CachedGGUFReader returns a GGUFMetadataReader that caches ReadMetadata results by URI for the given TTL.
|
||||
func CachedGGUFReader(underlying GGUFMetadataReader, ttl time.Duration) GGUFMetadataReader {
|
||||
return &cachedGGUFReader{underlying: underlying, ttl: ttl, cache: make(map[string]ggufCacheEntry)}
|
||||
}
|
||||
|
||||
// DefaultCachedSizeResolver returns a cached SizeResolver using the default
// implementation and default TTL (15 min).
// A single shared, process-wide cache is used so repeated HEAD requests for
// the same URI are avoided across requests.
func DefaultCachedSizeResolver() SizeResolver {
	return defaultCachedSizeResolver
}
|
||||
|
||||
// DefaultCachedGGUFReader returns a cached GGUFMetadataReader using the
// default implementation and default TTL (15 min).
// A single shared, process-wide cache is used so repeated GGUF metadata
// fetches for the same URI are avoided across requests.
func DefaultCachedGGUFReader() GGUFMetadataReader {
	return defaultCachedGGUFReader
}
|
||||
|
||||
// Process-wide singleton caches backing DefaultCachedSizeResolver and
// DefaultCachedGGUFReader.
var (
	defaultCachedSizeResolver = CachedSizeResolver(defaultSizeResolver{}, defaultEstimateCacheTTL)
	defaultCachedGGUFReader   = CachedGGUFReader(defaultGGUFReader{}, defaultEstimateCacheTTL)
)
|
||||
152
pkg/vram/estimate.go
Normal file
152
pkg/vram/estimate.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package vram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
)
|
||||
|
||||
// weightExts lists the file extensions treated as model weights when
// summing download sizes.
var weightExts = map[string]bool{
	".gguf": true, ".safetensors": true, ".bin": true, ".pt": true,
}

// isWeightFile reports whether nameOrURI's extension (case-insensitive)
// identifies a recognized model-weight file.
func isWeightFile(nameOrURI string) bool {
	base := path.Base(nameOrURI)
	return weightExts[strings.ToLower(path.Ext(base))]
}
|
||||
|
||||
// isGGUF reports whether nameOrURI names a .gguf file (case-insensitive).
func isGGUF(nameOrURI string) bool {
	ext := strings.ToLower(path.Ext(path.Base(nameOrURI)))
	return ext == ".gguf"
}
|
||||
|
||||
func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) (EstimateResult, error) {
|
||||
if opts.ContextLength == 0 {
|
||||
opts.ContextLength = 8192
|
||||
}
|
||||
if opts.KVQuantBits == 0 {
|
||||
opts.KVQuantBits = 16
|
||||
}
|
||||
|
||||
var sizeBytes uint64
|
||||
var ggufSize uint64
|
||||
var firstGGUFURI string
|
||||
for i := range files {
|
||||
f := &files[i]
|
||||
if !isWeightFile(f.URI) {
|
||||
continue
|
||||
}
|
||||
sz := f.Size
|
||||
if sz <= 0 && sizeResolver != nil {
|
||||
var err error
|
||||
sz, err = sizeResolver.ContentLength(ctx, f.URI)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
sizeBytes += uint64(sz)
|
||||
if isGGUF(f.URI) {
|
||||
ggufSize += uint64(sz)
|
||||
if firstGGUFURI == "" {
|
||||
firstGGUFURI = f.URI
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sizeDisplay := FormatBytes(sizeBytes)
|
||||
|
||||
var vramBytes uint64
|
||||
if ggufSize > 0 {
|
||||
var meta *GGUFMeta
|
||||
if ggufReader != nil && firstGGUFURI != "" {
|
||||
meta, _ = ggufReader.ReadMetadata(ctx, firstGGUFURI)
|
||||
}
|
||||
if meta != nil && (meta.BlockCount > 0 || meta.EmbeddingLength > 0) {
|
||||
nLayers := meta.BlockCount
|
||||
if nLayers == 0 {
|
||||
nLayers = 32
|
||||
}
|
||||
dModel := meta.EmbeddingLength
|
||||
if dModel == 0 {
|
||||
dModel = 4096
|
||||
}
|
||||
headCountKV := meta.HeadCountKV
|
||||
if headCountKV == 0 {
|
||||
headCountKV = meta.HeadCount
|
||||
}
|
||||
if headCountKV == 0 {
|
||||
headCountKV = 8
|
||||
}
|
||||
gpuLayers := opts.GPULayers
|
||||
if gpuLayers <= 0 {
|
||||
gpuLayers = int(nLayers)
|
||||
}
|
||||
ctxLen := opts.ContextLength
|
||||
bKV := uint32(opts.KVQuantBits / 8)
|
||||
if bKV == 0 {
|
||||
bKV = 4
|
||||
}
|
||||
M_model := ggufSize
|
||||
M_KV := uint64(bKV) * uint64(dModel) * uint64(nLayers) * uint64(ctxLen)
|
||||
if headCountKV > 0 && meta.HeadCount > 0 {
|
||||
M_KV = uint64(bKV) * uint64(dModel) * uint64(headCountKV) * uint64(ctxLen)
|
||||
}
|
||||
P := M_model * 2
|
||||
M_overhead := uint64(0.02*float64(P) + 0.15*1e9)
|
||||
vramBytes = M_model + M_KV + M_overhead
|
||||
if nLayers > 0 && gpuLayers < int(nLayers) {
|
||||
layerRatio := float64(gpuLayers) / float64(nLayers)
|
||||
vramBytes = uint64(layerRatio*float64(M_model)) + M_KV + M_overhead
|
||||
}
|
||||
} else {
|
||||
vramBytes = sizeOnlyVRAM(ggufSize, opts.ContextLength)
|
||||
}
|
||||
} else if sizeBytes > 0 {
|
||||
vramBytes = sizeOnlyVRAM(sizeBytes, opts.ContextLength)
|
||||
}
|
||||
|
||||
return EstimateResult{
|
||||
SizeBytes: sizeBytes,
|
||||
SizeDisplay: sizeDisplay,
|
||||
VRAMBytes: vramBytes,
|
||||
VRAMDisplay: FormatBytes(vramBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sizeOnlyVRAM approximates VRAM as the on-disk size plus a flat per-token
// KV allowance when no model metadata is available. The wrap-around check
// keeps the result monotonic for extremely large sizes.
func sizeOnlyVRAM(sizeOnDisk uint64, ctxLen uint32) uint64 {
	const perTokenBytes = 2 * 1024
	total := sizeOnDisk + uint64(ctxLen)*perTokenBytes
	if total < sizeOnDisk { // uint64 overflow
		return sizeOnDisk
	}
	return total
}
|
||||
|
||||
// FormatBytes renders n as a human-readable decimal (SI, base-1000) size
// string, e.g. 2_500_000_000 -> "2.5 GB".
func FormatBytes(n uint64) string {
	const unit = 1000
	if n < unit {
		return fmt.Sprintf("%d B", n)
	}
	const prefixes = "KMGTPE"
	value := float64(n)
	idx := -1
	for value >= unit && idx < len(prefixes)-1 {
		value /= unit
		idx++
	}
	return fmt.Sprintf("%.1f %cB", value, prefixes[idx])
}
|
||||
|
||||
// defaultSizeResolver resolves remote sizes via the downloader package.
type defaultSizeResolver struct{}

// ContentLength returns the content length in bytes for uri, delegating to
// downloader.URI's ContentLength.
func (defaultSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
	return downloader.URI(uri).ContentLength(ctx)
}
|
||||
|
||||
// DefaultSizeResolver returns the uncached, downloader-backed SizeResolver.
func DefaultSizeResolver() SizeResolver {
	return defaultSizeResolver{}
}
|
||||
|
||||
// DefaultGGUFReader returns the uncached GGUF metadata reader.
func DefaultGGUFReader() GGUFMetadataReader {
	return defaultGGUFReader{}
}
|
||||
137
pkg/vram/estimate_test.go
Normal file
137
pkg/vram/estimate_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package vram_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/vram"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type fakeSizeResolver map[string]int64
|
||||
|
||||
func (f fakeSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) {
|
||||
if n, ok := f[uri]; ok {
|
||||
return int64(n), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// fakeGGUFReader is a stub GGUFMetadataReader returning canned metadata
// keyed by URI. Unknown URIs yield a nil *GGUFMeta and no error, which
// Estimate treats as "no metadata available".
type fakeGGUFReader map[string]*GGUFMeta

// ReadMetadata implements the GGUFMetadataReader interface against the
// canned map.
func (f fakeGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
	return f[uri], nil
}
|
||||
|
||||
// Specs for Estimate: size accumulation across weight files, resolver
// fallback when FileInput.Size is unset, and the two VRAM paths (size-only
// heuristic vs. GGUF-metadata-based).
var _ = Describe("Estimate", func() {
	ctx := context.Background()

	Describe("empty or non-GGUF inputs", func() {
		It("returns zero size and vram for nil files", func() {
			opts := EstimateOptions{ContextLength: 8192}
			res, err := Estimate(ctx, nil, opts, nil, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeBytes).To(Equal(uint64(0)))
			Expect(res.VRAMBytes).To(Equal(uint64(0)))
			Expect(res.SizeDisplay).To(Equal("0 B"))
		})

		// NOTE(review): the title understates the behavior — any recognized
		// weight extension counts; .txt is excluded because it is not a
		// weight extension, not because only .gguf counts.
		It("counts only .gguf files and ignores other extensions", func() {
			files := []FileInput{
				{URI: "http://a/model.gguf", Size: 1_000_000_000},
				{URI: "http://a/readme.txt", Size: 100},
			}
			opts := EstimateOptions{ContextLength: 8192}
			res, err := Estimate(ctx, files, opts, nil, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeBytes).To(Equal(uint64(1_000_000_000)))
		})

		It("sums size for multiple non-GGUF weight files (e.g. safetensors)", func() {
			files := []FileInput{
				{URI: "http://hf.co/model/model.safetensors", Size: 2_000_000_000},
				{URI: "http://hf.co/model/model2.safetensors", Size: 3_000_000_000},
			}
			opts := EstimateOptions{ContextLength: 8192}
			res, err := Estimate(ctx, files, opts, nil, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeBytes).To(Equal(uint64(5_000_000_000)))
		})
	})

	Describe("GGUF size and resolver", func() {
		It("uses size resolver when file size is not set", func() {
			sizes := fakeSizeResolver{"http://example.com/model.gguf": 1_500_000_000}
			opts := EstimateOptions{ContextLength: 8192}
			files := []FileInput{{URI: "http://example.com/model.gguf"}}

			res, err := Estimate(ctx, files, opts, sizes, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeBytes).To(Equal(uint64(1_500_000_000)))
			Expect(res.VRAMBytes).To(BeNumerically(">=", res.SizeBytes))
			Expect(res.SizeDisplay).To(Equal("1.5 GB"))
		})

		// No ggufReader is passed, so the size-only heuristic applies.
		It("uses size-only VRAM formula when metadata is missing and size is large", func() {
			sizes := fakeSizeResolver{"http://a/model.gguf": 10_000_000_000}
			opts := EstimateOptions{ContextLength: 8192}
			files := []FileInput{{URI: "http://a/model.gguf"}}

			res, err := Estimate(ctx, files, opts, sizes, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.VRAMBytes).To(BeNumerically(">", 10_000_000_000))
		})

		It("sums size for multiple GGUF shards", func() {
			files := []FileInput{
				{URI: "http://a/shard1.gguf", Size: 10_000_000_000},
				{URI: "http://a/shard2.gguf", Size: 5_000_000_000},
			}
			opts := EstimateOptions{ContextLength: 8192}

			res, err := Estimate(ctx, files, opts, nil, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000)))
		})

		It("formats size display correctly", func() {
			files := []FileInput{{URI: "http://a/model.gguf", Size: 2_500_000_000}}
			opts := EstimateOptions{ContextLength: 8192}

			res, err := Estimate(ctx, files, opts, nil, nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeDisplay).To(Equal("2.5 GB"))
		})
	})

	Describe("GGUF with metadata reader", func() {
		// GPULayers < BlockCount exercises the partial-offload branch.
		It("uses metadata for VRAM when reader returns meta and partial offload", func() {
			meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096}
			reader := fakeGGUFReader{"http://a/model.gguf": meta}
			opts := EstimateOptions{ContextLength: 8192, GPULayers: 20}
			files := []FileInput{{URI: "http://a/model.gguf", Size: 8_000_000_000}}

			res, err := Estimate(ctx, files, opts, nil, reader)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.VRAMBytes).To(BeNumerically(">", 0))
		})

		It("uses metadata head counts for KV and yields vram > size", func() {
			files := []FileInput{{URI: "http://a/model.gguf", Size: 15_000_000_000}}
			meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, HeadCount: 32, HeadCountKV: 8}
			reader := fakeGGUFReader{"http://a/model.gguf": meta}
			opts := EstimateOptions{ContextLength: 8192}

			res, err := Estimate(ctx, files, opts, nil, reader)
			Expect(err).ToNot(HaveOccurred())
			Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000)))
			Expect(res.VRAMBytes).To(BeNumerically(">", res.SizeBytes))
		})
	})
})
|
||||
|
||||
// Direct unit coverage for the exported FormatBytes helper.
var _ = Describe("FormatBytes", func() {
	It("formats 2.5e9 as 2.5 GB", func() {
		Expect(FormatBytes(2_500_000_000)).To(Equal("2.5 GB"))
	})
})
|
||||
46
pkg/vram/gguf_reader.go
Normal file
46
pkg/vram/gguf_reader.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package vram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
)
|
||||
|
||||
// defaultGGUFReader parses GGUF metadata either from a local file or
// remotely, using the gguf-parser library.
type defaultGGUFReader struct{}

// ReadMetadata parses GGUF metadata for uri.
// Local URIs (downloader.LocalPrefix) are parsed from disk; HTTP(S) URIs
// are parsed remotely; any other scheme yields (nil, nil) — callers treat
// missing metadata as "estimate from size only", not as an error.
func (defaultGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) {
	u := downloader.URI(uri)
	urlStr := u.ResolveURL()

	if strings.HasPrefix(uri, downloader.LocalPrefix) {
		f, err := gguf.ParseGGUFFile(urlStr)
		if err != nil {
			return nil, err
		}
		return ggufFileToMeta(f), nil
	}
	if !u.LooksLikeHTTPURL() {
		// Unsupported scheme: no metadata, but not a failure.
		return nil, nil
	}
	f, err := gguf.ParseGGUFFileRemote(ctx, urlStr)
	if err != nil {
		return nil, err
	}
	return ggufFileToMeta(f), nil
}
|
||||
|
||||
// ggufFileToMeta maps the parsed GGUF architecture fields onto the minimal
// GGUFMeta used for VRAM estimation. When the model does not declare a
// separate KV head count, it is defaulted to the attention head count.
func ggufFileToMeta(f *gguf.GGUFFile) *GGUFMeta {
	arch := f.Architecture()
	meta := &GGUFMeta{
		BlockCount:      uint32(arch.BlockCount),
		EmbeddingLength: uint32(arch.EmbeddingLength),
		HeadCount:       uint32(arch.AttentionHeadCount),
		HeadCountKV:     uint32(arch.AttentionHeadCountKV),
	}
	if meta.HeadCountKV == 0 {
		meta.HeadCountKV = meta.HeadCount
	}
	return meta
}
|
||||
42
pkg/vram/types.go
Normal file
42
pkg/vram/types.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package vram
|
||||
|
||||
import "context"
|
||||
|
||||
// FileInput represents a single model file for estimation (URI and optional
// pre-known size).
type FileInput struct {
	URI  string // source location of the file
	Size int64  // size in bytes when already known; <= 0 means "resolve it"
}
|
||||
|
||||
// SizeResolver returns the content length in bytes for a given URI.
// Estimate skips files whose size cannot be resolved, so implementations
// may return an error for unreachable URIs.
type SizeResolver interface {
	ContentLength(ctx context.Context, uri string) (int64, error)
}
|
||||
|
||||
// GGUFMeta holds parsed GGUF metadata used for VRAM estimation.
type GGUFMeta struct {
	BlockCount      uint32 // number of transformer blocks (model layers)
	EmbeddingLength uint32 // embedding width of the model
	HeadCount       uint32 // attention head count
	HeadCountKV     uint32 // KV head count; readers default it to HeadCount when absent
}
|
||||
|
||||
// GGUFMetadataReader reads GGUF metadata from a URI (e.g. via HTTP Range).
// A (nil, nil) result means "no metadata available" and is not an error.
type GGUFMetadataReader interface {
	ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error)
}
|
||||
|
||||
// EstimateOptions configures VRAM/size estimation.
type EstimateOptions struct {
	ContextLength uint32 // context size in tokens; 0 is defaulted to 8192 by Estimate
	GPULayers     int    // layers offloaded to GPU; <= 0 means all layers
	KVQuantBits   int    // bits per KV-cache element; 0 is defaulted to 16
}
|
||||
|
||||
// EstimateResult holds estimated download size and VRAM with display strings.
type EstimateResult struct {
	SizeBytes   uint64 // total size of recognized weight files, in bytes
	SizeDisplay string // human-readable form of SizeBytes (see FormatBytes)
	VRAMBytes   uint64 // estimated VRAM requirement, in bytes
	VRAMDisplay string // human-readable form of VRAMBytes
}
|
||||
13
pkg/vram/vram_suite_test.go
Normal file
13
pkg/vram/vram_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package vram_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestVram is the standard `go test` entry point that hands control to the
// Ginkgo specs defined in this package.
func TestVram(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Vram test suite")
}
|
||||
Reference in New Issue
Block a user