mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-20 06:39:01 -04:00
Compare commits
12 Commits
feat/depth
...
feat/black
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6715d75f22 | ||
|
|
2f7e76f0f3 | ||
|
|
bca250e2bd | ||
|
|
079ac0e15a | ||
|
|
2e734bf560 | ||
|
|
72d46c1115 | ||
|
|
606128e4e9 | ||
|
|
59c7ad5153 | ||
|
|
78d682224a | ||
|
|
29dbba7a25 | ||
|
|
4ad754eea3 | ||
|
|
67692cb984 |
@@ -70,6 +70,12 @@ if [ "${BUILD_TYPE:-}" = "vulkan" ] && [ "${SKIP_DRIVERS:-false}" = "false" ]; t
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
# Mesa Vulkan ICD drivers (ANV/RADV/lavapipe + Arm SoC) and their ICD
|
||||
# manifests. The LunarG SDK below only provides the loader and shader
|
||||
# tooling, not hardware drivers — without Mesa the packaged Vulkan backend
|
||||
# would ship a loader that finds no GPU. package-gpu-libs.sh bundles these
|
||||
# .so files plus their deps into the backend so it stays self-contained.
|
||||
apt-get install -y mesa-vulkan-drivers libdrm2
|
||||
if [ "amd64" = "${TARGETARCH:-}" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz"
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz
|
||||
|
||||
@@ -65,7 +65,12 @@ RUN <<EOT bash
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils && \
|
||||
apt-get install -y mesa-vulkan-drivers libdrm2
|
||||
# Mesa Vulkan ICD drivers (ANV/RADV/lavapipe) + their manifests. The
|
||||
# LunarG SDK below only provides the loader and shader tooling, not
|
||||
# hardware drivers — without Mesa, package-gpu-libs.sh has no ICD to
|
||||
# bundle and the packaged backend finds no GPU at runtime.
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
|
||||
@@ -66,7 +66,12 @@ RUN <<EOT bash
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils && \
|
||||
apt-get install -y mesa-vulkan-drivers libdrm2
|
||||
# Mesa Vulkan ICD drivers (ANV/RADV/lavapipe) + their manifests. The
|
||||
# LunarG SDK below only provides the loader and shader tooling, not
|
||||
# hardware drivers — without Mesa, package-gpu-libs.sh has no ICD to
|
||||
# bundle and the packaged backend finds no GPU at runtime.
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=71af16a6b7f6fb7315b346b4a51aad530599c3f5
|
||||
IK_LLAMA_VERSION?=b3dfb7858cfcb9166e92f366e5af87f19ebc94be
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -67,7 +67,7 @@ sources/CrispASR:
|
||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||
# which is correct both standalone and as a subproject. Idempotent.
|
||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
||||
sed -i.bak 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt && rm -f sources/CrispASR/src/CMakeLists.txt.bak
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
@@ -47,6 +47,74 @@ extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
// --- word-level timestamp accessors ---
|
||||
extern "C" {
|
||||
int crispasr_session_result_n_words(crispasr_session_result *r, int seg_i);
|
||||
const char *crispasr_session_result_word_text(crispasr_session_result *r,
|
||||
int seg_i, int word_i);
|
||||
int64_t crispasr_session_result_word_t0(crispasr_session_result *r, int seg_i,
|
||||
int word_i);
|
||||
int64_t crispasr_session_result_word_t1(crispasr_session_result *r, int seg_i,
|
||||
int word_i);
|
||||
|
||||
// Parakeet-specific word accessors
|
||||
int crispasr_parakeet_result_n_words(void *r);
|
||||
const char *crispasr_parakeet_result_word_text(void *r, int word_i);
|
||||
int64_t crispasr_parakeet_result_word_t0(void *r, int word_i);
|
||||
int64_t crispasr_parakeet_result_word_t1(void *r, int word_i);
|
||||
}
|
||||
|
||||
// Exposes the current global result handle as an opaque pointer. The other
// accessors in this file treat a null handle as "no result".
void *get_result(void) {
  return g_result;
}
|
||||
|
||||
int get_word_count(int seg_i) {
|
||||
if (!g_result)
|
||||
return 0;
|
||||
return crispasr_session_result_n_words(g_result, seg_i);
|
||||
}
|
||||
|
||||
const char *get_word_text(int seg_i, int word_i) {
|
||||
if (!g_result)
|
||||
return "";
|
||||
return crispasr_session_result_word_text(g_result, seg_i, word_i);
|
||||
}
|
||||
|
||||
int64_t get_word_t0(int seg_i, int word_i) {
|
||||
if (!g_result)
|
||||
return 0;
|
||||
return crispasr_session_result_word_t0(g_result, seg_i, word_i);
|
||||
}
|
||||
|
||||
int64_t get_word_t1(int seg_i, int word_i) {
|
||||
if (!g_result)
|
||||
return 0;
|
||||
return crispasr_session_result_word_t1(g_result, seg_i, word_i);
|
||||
}
|
||||
|
||||
// Parakeet-specific word accessors
|
||||
int get_parakeet_word_count(void) {
|
||||
if (!g_result)
|
||||
return 0;
|
||||
return crispasr_parakeet_result_n_words(g_result);
|
||||
}
|
||||
|
||||
const char *get_parakeet_word_text(int word_i) {
|
||||
if (!g_result)
|
||||
return "";
|
||||
return crispasr_parakeet_result_word_text(g_result, word_i);
|
||||
}
|
||||
|
||||
int64_t get_parakeet_word_t0(int word_i) {
|
||||
if (!g_result)
|
||||
return 0;
|
||||
return crispasr_parakeet_result_word_t0(g_result, word_i);
|
||||
}
|
||||
|
||||
int64_t get_parakeet_word_t1(int word_i) {
|
||||
if (!g_result)
|
||||
return 0;
|
||||
return crispasr_parakeet_result_word_t1(g_result, word_i);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
|
||||
@@ -20,4 +20,18 @@ float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float
|
||||
void tts_free(float *pcm);
|
||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||
|
||||
// --- word-level timestamp accessors ---
|
||||
// Session-based (works for whisper-like backends)
|
||||
void *get_result(void);
|
||||
int get_word_count(int seg_i);
|
||||
const char *get_word_text(int seg_i, int word_i);
|
||||
int64_t get_word_t0(int seg_i, int word_i);
|
||||
int64_t get_word_t1(int seg_i, int word_i);
|
||||
|
||||
// Parakeet-specific (global word list, no segment index)
|
||||
int get_parakeet_word_count(void);
|
||||
const char *get_parakeet_word_text(int word_i);
|
||||
int64_t get_parakeet_word_t0(int word_i);
|
||||
int64_t get_parakeet_word_t1(int word_i);
|
||||
}
|
||||
|
||||
@@ -34,6 +34,18 @@ var (
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
|
||||
// Word-level timestamp accessors (session-based, per-segment)
|
||||
CppGetWordCount func(segI int) int
|
||||
CppGetWordText func(segI int, wordI int) string
|
||||
CppGetWordT0 func(segI int, wordI int) int64
|
||||
CppGetWordT1 func(segI int, wordI int) int64
|
||||
|
||||
// Parakeet-specific word accessors (global, no segment index)
|
||||
CppGetParakeetWordCount func() int
|
||||
CppGetParakeetWordText func(wordI int) string
|
||||
CppGetParakeetWordT0 func(wordI int) int64
|
||||
CppGetParakeetWordT1 func(wordI int) int64
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
@@ -290,10 +302,36 @@ func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRe
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
|
||||
// Populate word-level timestamps. Try session-based functions first
|
||||
// (per-segment); fall back to parakeet-specific functions (global word
|
||||
// list with no segment index — only populated on the first segment to
|
||||
// avoid duplication).
|
||||
words := []*pb.TranscriptWord{}
|
||||
wordCount := CppGetWordCount(i)
|
||||
if wordCount == 0 && i == 0 {
|
||||
wordCount = CppGetParakeetWordCount()
|
||||
for j := 0; j < wordCount; j++ {
|
||||
words = append(words, &pb.TranscriptWord{
|
||||
Start: CppGetParakeetWordT0(j) * (10000000),
|
||||
End: CppGetParakeetWordT1(j) * (10000000),
|
||||
Text: strings.ToValidUTF8(strings.Clone(CppGetParakeetWordText(j)), "<22>"),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < wordCount; j++ {
|
||||
words = append(words, &pb.TranscriptWord{
|
||||
Start: CppGetWordT0(i, j) * (10000000),
|
||||
End: CppGetWordT1(i, j) * (10000000),
|
||||
Text: strings.ToValidUTF8(strings.Clone(CppGetWordText(i, j)), "<22>"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
Words: words,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
@@ -44,6 +44,14 @@ func main() {
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
{&CppGetWordCount, "get_word_count"},
|
||||
{&CppGetWordText, "get_word_text"},
|
||||
{&CppGetWordT0, "get_word_t0"},
|
||||
{&CppGetWordT1, "get_word_t1"},
|
||||
{&CppGetParakeetWordCount, "get_parakeet_word_count"},
|
||||
{&CppGetParakeetWordText, "get_parakeet_word_text"},
|
||||
{&CppGetParakeetWordT0, "get_parakeet_word_t0"},
|
||||
{&CppGetParakeetWordT1, "get_parakeet_word_t1"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
|
||||
@@ -1,23 +1,68 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# L0 packaging stub: copy the binary, run.sh and libparakeet.so* into
|
||||
# package/. The full ldd walk (libc, libstdc++, libgomp, GPU runtimes,
|
||||
# arch detection) lands in L3, mirroring backend/go/whisper/package.sh.
|
||||
# Bundle the parakeet-cpp-grpc binary, libparakeet.so, the core runtime
|
||||
# libs (libc/libstdc++/libgomp + ld.so) and the GPU runtime for the active
|
||||
# BUILD_TYPE so the package is self-contained. Mirrors
|
||||
# backend/go/whisper/package.sh; run.sh routes the (CGO_ENABLED=0) binary
|
||||
# through lib/ld.so so the packaged libc is used instead of the host's.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/parakeet-cpp-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libparakeet.so + any soname symlinks (libparakeet.so.X, libparakeet.so.X.Y).
|
||||
# libparakeet.so + any soname symlinks (libparakeet.so.X[.Y]). purego.Dlopen
|
||||
# resolves it via LD_LIBRARY_PATH, which run.sh points at lib/.
|
||||
cp -avf "$CURDIR"/libparakeet.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libparakeet.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "L0 package layout (full ldd walk lands in L3):"
|
||||
# Detect architecture and copy the core runtime libs libparakeet.so links
|
||||
# against, plus the matching dynamic loader as lib/ld.so.
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||
elif [ "$(uname -s)" = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries (CUDA/ROCm/Intel/Vulkan loader + ICDs + drivers)
|
||||
# based on BUILD_TYPE so the backend can reach the GPU without the runtime
|
||||
# base image shipping those drivers.
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
@@ -71,6 +72,16 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
|
||||
}
|
||||
|
||||
// Reap *.partial downloads abandoned by a previous run (killed mid-transfer
|
||||
// by an OOM/restart, or stalled before cleanup could run). The 24h window
|
||||
// is well beyond any legitimate in-flight download, so this never trims an
|
||||
// active transfer; it just stops dead partials accumulating on the volume.
|
||||
if removed, cErr := downloader.CleanupStalePartialFiles(options.SystemState.Model.ModelsPath, 24*time.Hour); cErr != nil {
|
||||
xlog.Warn("Failed to reap stale partial downloads", "error", cErr)
|
||||
} else if removed > 0 {
|
||||
xlog.Info("Reaped stale partial downloads", "count", removed)
|
||||
}
|
||||
if options.GeneratedContentDir != "" {
|
||||
err := os.MkdirAll(options.GeneratedContentDir, 0o750)
|
||||
if err != nil {
|
||||
|
||||
190
core/config/hardware_defaults.go
Normal file
190
core/config/hardware_defaults.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Hardware-driven model-config defaults.
|
||||
//
|
||||
// This sits alongside the other config overriders (ApplyInferenceDefaults for
|
||||
// model families, guessDefaultsFromFile for GGUF/NGPULayers): they all
|
||||
// heuristically fill ModelConfig values the user left unset. Hardware tuning is
|
||||
// the same domain — "adjust the config from the device that will run it" — so
|
||||
// it lives here rather than scattered into the backend or a separate package.
|
||||
//
|
||||
// The heuristics are parameterized on a GPU descriptor (not on direct
|
||||
// detection) so they apply in both deployment shapes: SetDefaults passes the
|
||||
// LocalGPU on a single host, and the distributed router passes the *selected
|
||||
// node's* reported GPU before loading there (the frontend that loaded the
|
||||
// config may have no GPU at all).
|
||||
|
||||
// GPU describes the device that will run a model.
//
// It is a plain value descriptor: the heuristics in this file only read its
// fields (ComputeCapability in IsNVIDIABlackwell, VRAM in DefaultParallelSlots).
|
||||
type GPU struct {
|
||||
// Vendor is "nvidia", "amd", … (matches xsysinfo vendor constants).
|
||||
Vendor string
|
||||
// ComputeCapability is the NVIDIA compute capability as "major.minor"
|
||||
// (e.g. "12.1" for GB10 / DGX Spark). Empty for non-NVIDIA / unknown.
|
||||
ComputeCapability string
|
||||
// VRAM is total device memory in bytes (0 = unknown).
|
||||
VRAM uint64
|
||||
}
|
||||
|
||||
// Physical batch (n_batch / n_ubatch) defaults.
//
// These are the only two values PhysicalBatch returns, and the only values
// IsManagedPhysicalBatch reports as managed.
|
||||
const (
|
||||
// DefaultPhysicalBatch is the conservative default when no hardware-specific
|
||||
// tuning applies. Matches backend.DefaultBatchSize.
|
||||
DefaultPhysicalBatch = 512
|
||||
// BlackwellPhysicalBatch is the default on NVIDIA Blackwell consumer GPUs
|
||||
// (sm_12x: sm_120 RTX 50-series, sm_121 GB10 / DGX Spark). A larger physical
|
||||
// batch materially lifts MoE prefill there (per-expert GEMM tiles fill
|
||||
// better); measured on a GB10 with Qwen3-30B-A3B to saturate around 2048.
|
||||
BlackwellPhysicalBatch = 2048
|
||||
)
|
||||
|
||||
// IsNVIDIABlackwell reports whether the GPU is in the NVIDIA Blackwell consumer
|
||||
// family (sm_12x). Datacenter Blackwell (B100/B200/GB200, sm_100 / cc 10.0)
|
||||
// reports a different compute capability and is intentionally not matched.
|
||||
func (g GPU) IsNVIDIABlackwell() bool {
|
||||
maj, _ := parseComputeCapability(g.ComputeCapability)
|
||||
return maj >= 12
|
||||
}
|
||||
|
||||
// PhysicalBatch returns the canonical physical batch (n_batch/n_ubatch) for the
|
||||
// given hardware, used when the model config leaves batch unset.
|
||||
func PhysicalBatch(g GPU) int {
|
||||
if g.IsNVIDIABlackwell() {
|
||||
return BlackwellPhysicalBatch
|
||||
}
|
||||
return DefaultPhysicalBatch
|
||||
}
|
||||
|
||||
// IsManagedPhysicalBatch reports whether n is a value PhysicalBatch assigns.
|
||||
// Callers that re-tune a value chosen by an upstream host (the distributed
|
||||
// router correcting the frontend's guess) use this to avoid clobbering an
|
||||
// explicit user batch such as 1024.
|
||||
func IsManagedPhysicalBatch(n int) bool {
|
||||
return n == DefaultPhysicalBatch || n == BlackwellPhysicalBatch
|
||||
}
|
||||
|
||||
// Parallel-slot (n_parallel) VRAM tiers. llama.cpp serializes requests at
|
||||
// n_parallel=1 (the backend default) and only auto-enables continuous batching
|
||||
// when n_parallel > 1 — so a single-slot default makes concurrent requests
|
||||
// queue. We default a slot count by GPU size so multi-user serving works out of
|
||||
// the box. With the backend's unified KV cache the slots SHARE the context
|
||||
// budget, so more slots add concurrency without multiplying KV memory.
//
// Each threshold is a byte count (a GiB figure shifted left by 30) and is
// compared against GPU.VRAM in DefaultParallelSlots, highest tier first.
|
||||
const (
|
||||
parallelSlotsVRAMHigh = uint64(32) << 30 // >=32 GiB -> 8 slots
|
||||
parallelSlotsVRAMMid  = uint64(8) << 30  // >=8 GiB  -> 4 slots
|
||||
parallelSlotsVRAMLow  = uint64(4) << 30  // >=4 GiB  -> 2 slots
|
||||
)
|
||||
|
||||
// DefaultParallelSlots returns the n_parallel default for the given GPU. Returns
|
||||
// 1 (no concurrency) when VRAM is unknown or too small, so we never change
|
||||
// behavior on CPU-only / tiny devices.
|
||||
func DefaultParallelSlots(g GPU) int {
|
||||
switch {
|
||||
case g.VRAM >= parallelSlotsVRAMHigh:
|
||||
return 8
|
||||
case g.VRAM >= parallelSlotsVRAMMid:
|
||||
return 4
|
||||
case g.VRAM >= parallelSlotsVRAMLow:
|
||||
return 2
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureParallelOption appends a VRAM-scaled "parallel:N" backend option when the
|
||||
// model doesn't already set one (and the GPU warrants concurrency). Returns the
|
||||
// possibly-extended options. Shared by the single-host config path
|
||||
// (ApplyHardwareDefaults) and the distributed router (per selected node).
|
||||
func EnsureParallelOption(opts []string, gpu GPU) []string {
|
||||
if slots := DefaultParallelSlots(gpu); slots > 1 && !hasParallelOption(opts) {
|
||||
return append(opts, fmt.Sprintf("parallel:%d", slots))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// hasParallelOption reports whether any backend option (a "name:value" string)
// is named parallel or n_parallel, so an explicit value is never overridden.
// The name is everything before the first ':' (or the whole entry when there
// is no ':'), compared after lower-casing and trimming surrounding whitespace.
func hasParallelOption(opts []string) bool {
	for _, opt := range opts {
		key := opt
		if sep := strings.Index(opt, ":"); sep != -1 {
			key = opt[:sep]
		}
		key = strings.TrimSpace(strings.ToLower(key))
		if key == "parallel" || key == "n_parallel" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// localGPU builds a GPU descriptor from local detection, used by SetDefaults on
|
||||
// a single host (the distributed router builds it from the selected node's
|
||||
// reported info instead). It is a package var so tests can inject a
|
||||
// deterministic device — detection does a live nvidia-smi call.
|
||||
var localGPU = func() GPU {
|
||||
vendor, _ := xsysinfo.DetectGPUVendor()
|
||||
vram, _ := xsysinfo.TotalAvailableVRAM()
|
||||
return GPU{
|
||||
Vendor: vendor,
|
||||
ComputeCapability: xsysinfo.NVIDIAComputeCapability(),
|
||||
VRAM: vram,
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyHardwareDefaults fills ModelConfig values that depend on the target GPU
|
||||
// and were left unset by the user. Currently: a larger physical batch on
|
||||
// Blackwell. Explicit config always wins (we only touch zero values).
|
||||
func ApplyHardwareDefaults(cfg *ModelConfig, gpu GPU) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if cfg.Batch == 0 && gpu.IsNVIDIABlackwell() {
|
||||
cfg.Batch = BlackwellPhysicalBatch
|
||||
xlog.Debug("[hardware_defaults] Blackwell GPU: defaulting physical batch",
|
||||
"batch", cfg.Batch, "compute_cap", gpu.ComputeCapability)
|
||||
}
|
||||
|
||||
// Enable concurrent serving by default on a capable GPU: without this the
|
||||
// llama.cpp backend runs n_parallel=1 and serializes multi-user requests
|
||||
// (continuous batching stays off). Unified KV means the slots share the
|
||||
// context budget, so this is concurrency without extra KV memory. Explicit
|
||||
// parallel/n_parallel in the model options always wins.
|
||||
if before := len(cfg.Options); true {
|
||||
cfg.Options = EnsureParallelOption(cfg.Options, gpu)
|
||||
if len(cfg.Options) > before {
|
||||
xlog.Debug("[hardware_defaults] defaulting parallel slots for concurrent serving",
|
||||
"option", cfg.Options[len(cfg.Options)-1], "vram_gib", gpu.VRAM>>30)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseComputeCapability splits a "major.minor" string into integer parts.
//
// Surrounding whitespace is ignored, both around the whole string and around
// each part. The result is:
//
//   - (-1, -1) when cc is empty/blank or the major part is not an integer;
//   - (major, 0) when there is no '.' or the minor part is not an integer
//     (a bad minor does not invalidate a good major);
//   - (major, minor) otherwise.
func parseComputeCapability(cc string) (int, int) {
	cc = strings.TrimSpace(cc)
	if cc == "" {
		return -1, -1
	}

	majorStr, minorStr := cc, "0"
	if dot := strings.IndexByte(cc, '.'); dot >= 0 {
		majorStr, minorStr = cc[:dot], cc[dot+1:]
	}

	major, err := strconv.Atoi(strings.TrimSpace(majorStr))
	if err != nil {
		return -1, -1
	}

	// Named "minor" rather than "min" so the Go 1.21 builtin is not shadowed.
	minor, err := strconv.Atoi(strings.TrimSpace(minorStr))
	if err != nil {
		// Atoi may return a non-zero value alongside an error (range errors),
		// so reset explicitly.
		minor = 0
	}
	return major, minor
}
|
||||
37
core/config/hardware_defaults_internal_test.go
Normal file
37
core/config/hardware_defaults_internal_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Single-instance path: SetDefaults applies hardware defaults from the local
|
||||
// GPU. The detection seam (localGPU) is injected so the path is deterministic
|
||||
// without a real GPU.
|
||||
var _ = Describe("SetDefaults hardware defaults (single-instance)", func() {
|
||||
var orig func() GPU
|
||||
BeforeEach(func() { orig = localGPU })
|
||||
AfterEach(func() { localGPU = orig })
|
||||
|
||||
It("sets the physical batch on a local Blackwell GPU", func() {
|
||||
localGPU = func() GPU { return GPU{ComputeCapability: "12.1"} }
|
||||
cfg := &ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
Expect(cfg.Batch).To(Equal(BlackwellPhysicalBatch))
|
||||
})
|
||||
|
||||
It("leaves batch unset on a non-Blackwell local GPU", func() {
|
||||
localGPU = func() GPU { return GPU{ComputeCapability: "8.9"} }
|
||||
cfg := &ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
Expect(cfg.Batch).To(Equal(0))
|
||||
})
|
||||
|
||||
It("never overrides an explicit batch", func() {
|
||||
localGPU = func() GPU { return GPU{ComputeCapability: "12.1"} }
|
||||
cfg := &ModelConfig{}
|
||||
cfg.Batch = 1024
|
||||
cfg.SetDefaults()
|
||||
Expect(cfg.Batch).To(Equal(1024))
|
||||
})
|
||||
})
|
||||
97
core/config/hardware_defaults_test.go
Normal file
97
core/config/hardware_defaults_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Hardware-driven config defaults", func() {
|
||||
DescribeTable("GPU.IsNVIDIABlackwell (sm_12x consumer family)",
|
||||
func(cc string, want bool) {
|
||||
Expect(GPU{ComputeCapability: cc}.IsNVIDIABlackwell()).To(Equal(want))
|
||||
},
|
||||
Entry("GB10 12.1", "12.1", true),
|
||||
Entry("RTX 50 12.0", "12.0", true),
|
||||
Entry("future 13.0", "13.0", true),
|
||||
Entry("Hopper 9.0", "9.0", false),
|
||||
Entry("Ada 8.9", "8.9", false),
|
||||
Entry("datacenter Blackwell sm_100 10.0", "10.0", false),
|
||||
Entry("unknown", "", false),
|
||||
)
|
||||
|
||||
Describe("PhysicalBatch / IsManagedPhysicalBatch", func() {
|
||||
It("returns the Blackwell batch on Blackwell", func() {
|
||||
Expect(PhysicalBatch(GPU{ComputeCapability: "12.1"})).To(Equal(BlackwellPhysicalBatch))
|
||||
})
|
||||
It("returns the default batch otherwise", func() {
|
||||
Expect(PhysicalBatch(GPU{ComputeCapability: "9.0"})).To(Equal(DefaultPhysicalBatch))
|
||||
Expect(PhysicalBatch(GPU{})).To(Equal(DefaultPhysicalBatch))
|
||||
})
|
||||
It("recognizes managed defaults but not explicit values", func() {
|
||||
Expect(IsManagedPhysicalBatch(DefaultPhysicalBatch)).To(BeTrue())
|
||||
Expect(IsManagedPhysicalBatch(BlackwellPhysicalBatch)).To(BeTrue())
|
||||
Expect(IsManagedPhysicalBatch(1024)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ApplyHardwareDefaults", func() {
|
||||
It("raises an unset batch to 2048 on Blackwell", func() {
|
||||
cfg := &ModelConfig{}
|
||||
ApplyHardwareDefaults(cfg, GPU{ComputeCapability: "12.1"})
|
||||
Expect(cfg.Batch).To(Equal(BlackwellPhysicalBatch))
|
||||
})
|
||||
It("leaves batch unset on non-Blackwell", func() {
|
||||
cfg := &ModelConfig{}
|
||||
ApplyHardwareDefaults(cfg, GPU{ComputeCapability: "9.0"})
|
||||
Expect(cfg.Batch).To(Equal(0))
|
||||
})
|
||||
It("never overrides an explicit batch", func() {
|
||||
cfg := &ModelConfig{}
|
||||
cfg.Batch = 1024
|
||||
ApplyHardwareDefaults(cfg, GPU{ComputeCapability: "12.1"})
|
||||
Expect(cfg.Batch).To(Equal(1024))
|
||||
})
|
||||
It("no-ops on nil", func() {
|
||||
Expect(func() { ApplyHardwareDefaults(nil, GPU{ComputeCapability: "12.1"}) }).ToNot(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
const gib = uint64(1) << 30
|
||||
|
||||
DescribeTable("DefaultParallelSlots (by VRAM)",
|
||||
func(vramGiB uint64, want int) {
|
||||
Expect(DefaultParallelSlots(GPU{VRAM: vramGiB * gib})).To(Equal(want))
|
||||
},
|
||||
Entry("GB10 119 GiB", uint64(119), 8),
|
||||
Entry("48 GiB", uint64(48), 8),
|
||||
Entry("24 GiB", uint64(24), 4),
|
||||
Entry("8 GiB", uint64(8), 4),
|
||||
Entry("6 GiB", uint64(6), 2),
|
||||
Entry("2 GiB", uint64(2), 1),
|
||||
Entry("unknown 0", uint64(0), 1),
|
||||
)
|
||||
|
||||
Describe("ApplyHardwareDefaults parallel slots", func() {
|
||||
It("adds a VRAM-scaled parallel option on a capable GPU", func() {
|
||||
cfg := &ModelConfig{}
|
||||
ApplyHardwareDefaults(cfg, GPU{ComputeCapability: "12.1", VRAM: 119 * gib})
|
||||
Expect(cfg.Options).To(ContainElement("parallel:8"))
|
||||
})
|
||||
It("scales the slot count down with VRAM", func() {
|
||||
cfg := &ModelConfig{}
|
||||
ApplyHardwareDefaults(cfg, GPU{VRAM: 24 * gib})
|
||||
Expect(cfg.Options).To(ContainElement("parallel:4"))
|
||||
})
|
||||
It("adds no parallel option on small/unknown VRAM", func() {
|
||||
cfg := &ModelConfig{}
|
||||
ApplyHardwareDefaults(cfg, GPU{VRAM: 2 * gib})
|
||||
Expect(cfg.Options).ToNot(ContainElement(ContainSubstring("parallel")))
|
||||
})
|
||||
It("never overrides an explicit parallel option", func() {
|
||||
cfg := &ModelConfig{Options: []string{"parallel:2"}}
|
||||
ApplyHardwareDefaults(cfg, GPU{VRAM: 119 * gib})
|
||||
Expect(cfg.Options).To(Equal([]string{"parallel:2"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1111,6 +1111,11 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
// This ensures gallery-installed and runtime-loaded models get optimal parameters.
|
||||
ApplyInferenceDefaults(cfg, cfg.Name, cfg.Model)
|
||||
|
||||
// Apply hardware-driven defaults (e.g. a larger physical batch on Blackwell).
|
||||
// Uses the local GPU here; in distributed mode the router re-applies the same
|
||||
// heuristics for the selected node's GPU before loading. Explicit config wins.
|
||||
ApplyHardwareDefaults(cfg, localGPU())
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
|
||||
defaultTopP := 0.95
|
||||
defaultTopK := 40
|
||||
|
||||
@@ -70,17 +70,20 @@ func GetNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
|
||||
// RegisterNodeRequest is the request body for registering a new worker node.
|
||||
type RegisterNodeRequest struct {
|
||||
Name string `json:"name"`
|
||||
NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
|
||||
Address string `json:"address"`
|
||||
HTTPAddress string `json:"http_address,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TotalVRAM uint64 `json:"total_vram,omitempty"`
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
GPUVendor string `json:"gpu_vendor,omitempty"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
Name string `json:"name"`
|
||||
NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
|
||||
Address string `json:"address"`
|
||||
HTTPAddress string `json:"http_address,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TotalVRAM uint64 `json:"total_vram,omitempty"`
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
GPUVendor string `json:"gpu_vendor,omitempty"`
|
||||
// GPUComputeCapability is the worker GPU's compute capability ("major.minor",
|
||||
// e.g. "12.1" for GB10). Used by the router for per-arch option tuning.
|
||||
GPUComputeCapability string `json:"gpu_compute_capability,omitempty"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
// MaxReplicasPerModel is the per-node cap on replicas of any single model.
|
||||
// Workers older than this field omit it; we coerce 0 → 1 below to preserve
|
||||
// historical single-replica behavior.
|
||||
@@ -152,17 +155,18 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
}
|
||||
|
||||
node := &nodes.BackendNode{
|
||||
Name: req.Name,
|
||||
NodeType: nodeType,
|
||||
Address: req.Address,
|
||||
HTTPAddress: req.HTTPAddress,
|
||||
TokenHash: tokenHash,
|
||||
TotalVRAM: req.TotalVRAM,
|
||||
AvailableVRAM: req.AvailableVRAM,
|
||||
TotalRAM: req.TotalRAM,
|
||||
AvailableRAM: req.AvailableRAM,
|
||||
GPUVendor: req.GPUVendor,
|
||||
MaxReplicasPerModel: maxReplicasPerModel,
|
||||
Name: req.Name,
|
||||
NodeType: nodeType,
|
||||
Address: req.Address,
|
||||
HTTPAddress: req.HTTPAddress,
|
||||
TokenHash: tokenHash,
|
||||
TotalVRAM: req.TotalVRAM,
|
||||
AvailableVRAM: req.AvailableVRAM,
|
||||
TotalRAM: req.TotalRAM,
|
||||
AvailableRAM: req.AvailableRAM,
|
||||
GPUVendor: req.GPUVendor,
|
||||
GPUComputeCapability: req.GPUComputeCapability,
|
||||
MaxReplicasPerModel: maxReplicasPerModel,
|
||||
}
|
||||
|
||||
ctx := c.Request().Context()
|
||||
|
||||
@@ -113,8 +113,13 @@ func (t *WebRTCTransport) sendLoop() {
|
||||
return
|
||||
}
|
||||
if err := t.dc.SendText(string(data)); err != nil {
|
||||
xlog.Error("data channel send failed", "error", err)
|
||||
return
|
||||
// Drop just this event and keep the loop alive: a single
|
||||
// failed send (e.g. an event over the negotiated SCTP
|
||||
// max-message-size) must not tear down the session and
|
||||
// silently drop every subsequent event. A genuinely dead
|
||||
// transport is handled by the <-t.closed case.
|
||||
xlog.Error("data channel send failed, dropping event", "error", err)
|
||||
continue
|
||||
}
|
||||
case <-t.closed:
|
||||
// Drain any remaining queued events before exiting
|
||||
@@ -122,7 +127,8 @@ func (t *WebRTCTransport) sendLoop() {
|
||||
select {
|
||||
case data := <-t.outEvents:
|
||||
if err := t.dc.SendText(string(data)); err != nil {
|
||||
return
|
||||
xlog.Error("data channel send failed while draining, dropping event", "error", err)
|
||||
continue
|
||||
}
|
||||
default:
|
||||
return
|
||||
|
||||
@@ -128,10 +128,13 @@ func RealtimeCalls(application *application.Application) echo.HandlerFunc {
|
||||
handleIncomingAudioTrack(track, transport)
|
||||
})
|
||||
|
||||
// Set the remote SDP (client's offer)
|
||||
// Set the remote SDP (client's offer). Raise the data-channel
|
||||
// max-message-size the browser advertised so pion permits the larger
|
||||
// realtime events some turns produce (e.g. tool calls), which would
|
||||
// otherwise be dropped on send. See realtime_webrtc_sctp.go.
|
||||
if err := pc.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
SDP: req.SDP,
|
||||
SDP: raiseDataChannelMaxMessageSize(req.SDP),
|
||||
}); err != nil {
|
||||
transport.Close()
|
||||
xlog.Error("failed to set remote description", "error", err)
|
||||
|
||||
29
core/http/endpoints/openai/realtime_webrtc_sctp.go
Normal file
29
core/http/endpoints/openai/realtime_webrtc_sctp.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// realtimeDataChannelMaxMessageSize is the SCTP max-message-size LocalAI honors
// for the "oai-events" data channel, in bytes.
//
// Browsers advertise a conservative max-message-size in their SDP offer (Chrome
// uses 262144 = 256 KiB). pion enforces the remote's advertised value on send,
// so a single realtime event larger than it cannot be sent: the SendText fails,
// the event is dropped, and the turn silently yields no response. Some turns
// legitimately produce a single JSON event above 256 KiB (notably tool calls
// with sizeable schemas or results). Browsers advertise this value
// conservatively but their SCTP stacks reassemble much larger messages, so we
// raise the value honored for our own server-generated events.
const realtimeDataChannelMaxMessageSize = 16 * 1024 * 1024 // 16 MiB

// maxMessageSizeAttrRe matches one SDP max-message-size attribute together with
// its decimal value.
var maxMessageSizeAttrRe = regexp.MustCompile(`a=max-message-size:\d+`)

// raiseDataChannelMaxMessageSize rewrites the SCTP max-message-size attribute in
// an SDP offer to realtimeDataChannelMaxMessageSize so pion permits larger
// outbound realtime events.
//
// The value is only ever raised. An attribute is left exactly as offered when:
//   - it is 0, which RFC 8841 defines as "messages of any size";
//   - it is already at or above realtimeDataChannelMaxMessageSize;
//   - it does not fit in a uint64, in which case it is necessarily larger.
//
// Offers that don't carry the attribute are returned unchanged.
func raiseDataChannelMaxMessageSize(sdp string) string {
	raised := fmt.Sprintf("a=max-message-size:%d", realtimeDataChannelMaxMessageSize)
	return maxMessageSizeAttrRe.ReplaceAllStringFunc(sdp, func(attr string) string {
		var advertised uint64
		if _, err := fmt.Sscanf(attr, "a=max-message-size:%d", &advertised); err != nil {
			// The regexp guarantees digits, so a failure here means overflow:
			// the peer advertised more than we would ever ask for.
			return attr
		}
		if advertised == 0 || advertised >= realtimeDataChannelMaxMessageSize {
			return attr
		}
		return raised
	})
}
|
||||
33
core/http/endpoints/openai/realtime_webrtc_sctp_test.go
Normal file
33
core/http/endpoints/openai/realtime_webrtc_sctp_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("raiseDataChannelMaxMessageSize", func() {
|
||||
It("raises a max-message-size the browser advertised", func() {
|
||||
offer := "v=0\r\nm=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\na=max-message-size:262144\r\n"
|
||||
out := raiseDataChannelMaxMessageSize(offer)
|
||||
Expect(out).To(ContainSubstring(fmt.Sprintf("a=max-message-size:%d", realtimeDataChannelMaxMessageSize)))
|
||||
Expect(out).NotTo(ContainSubstring("a=max-message-size:262144"))
|
||||
})
|
||||
|
||||
It("leaves an offer without the attribute unchanged", func() {
|
||||
offer := "v=0\r\nm=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n"
|
||||
Expect(raiseDataChannelMaxMessageSize(offer)).To(Equal(offer))
|
||||
})
|
||||
|
||||
It("rewrites every occurrence", func() {
|
||||
offer := "a=max-message-size:1024\r\na=max-message-size:262144\r\n"
|
||||
out := raiseDataChannelMaxMessageSize(offer)
|
||||
Expect(strings.Count(out, fmt.Sprintf("a=max-message-size:%d", realtimeDataChannelMaxMessageSize))).To(Equal(2))
|
||||
})
|
||||
|
||||
It("raises above the 256 KiB browsers advertise", func() {
|
||||
Expect(realtimeDataChannelMaxMessageSize).To(BeNumerically(">", 262144))
|
||||
})
|
||||
})
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Discard unsaved changes?",
|
||||
"message": "You have unsaved changes that will be lost if you leave this page.",
|
||||
"leave": "Leave"
|
||||
},
|
||||
"actions": {
|
||||
"save": "Speichern",
|
||||
"saving": "Speichern...",
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Discard unsaved changes?",
|
||||
"message": "You have unsaved changes that will be lost if you leave this page.",
|
||||
"leave": "Leave"
|
||||
},
|
||||
"actions": {
|
||||
"save": "Save",
|
||||
"saving": "Saving...",
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Discard unsaved changes?",
|
||||
"message": "You have unsaved changes that will be lost if you leave this page.",
|
||||
"leave": "Leave"
|
||||
},
|
||||
"actions": {
|
||||
"save": "Guardar",
|
||||
"saving": "Guardando...",
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Discard unsaved changes?",
|
||||
"message": "You have unsaved changes that will be lost if you leave this page.",
|
||||
"leave": "Leave"
|
||||
},
|
||||
"actions": {
|
||||
"save": "Simpan",
|
||||
"saving": "Menyimpan...",
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Scartare le modifiche non salvate?",
|
||||
"message": "Hai modifiche non salvate che andranno perse se esci da questa pagina.",
|
||||
"leave": "Esci"
|
||||
},
|
||||
"actions": {
|
||||
"save": "Salva",
|
||||
"saving": "Salvataggio...",
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Discard unsaved changes?",
|
||||
"message": "You have unsaved changes that will be lost if you leave this page.",
|
||||
"leave": "Leave"
|
||||
},
|
||||
"actions": {
|
||||
"save": "저장",
|
||||
"saving": "저장 중...",
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
{
|
||||
"unsaved": {
|
||||
"title": "Discard unsaved changes?",
|
||||
"message": "You have unsaved changes that will be lost if you leave this page.",
|
||||
"leave": "Leave"
|
||||
},
|
||||
"actions": {
|
||||
"save": "保存",
|
||||
"saving": "保存中...",
|
||||
|
||||
@@ -2427,6 +2427,40 @@ select.input {
|
||||
border-radius: var(--radius-lg);
|
||||
}
|
||||
|
||||
/* ResponsiveTable: stack dense tables into label/value cards on narrow screens
|
||||
instead of a sideways scroll. Labels come from data-label (mirrored from the
|
||||
<thead> by the ResponsiveTable component). */
|
||||
@media (max-width: 640px) {
|
||||
/* Direct-child selectors only: a nested table inside a cell renders normally.
|
||||
min-width override defeats any inline min-width set for the desktop layout. */
|
||||
.table--responsive { border: none; min-width: 0 !important; }
|
||||
.table--responsive > thead { display: none; }
|
||||
.table--responsive > tbody > tr {
|
||||
display: block;
|
||||
border: 1px solid var(--color-border-subtle);
|
||||
border-radius: var(--radius-md);
|
||||
background: var(--color-surface-raised);
|
||||
margin: var(--spacing-sm);
|
||||
padding: var(--spacing-xs) var(--spacing-sm);
|
||||
}
|
||||
.table--responsive > tbody > tr > td {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: var(--spacing-md);
|
||||
border: none;
|
||||
padding: var(--spacing-xs) 0;
|
||||
text-align: right;
|
||||
}
|
||||
.table--responsive > tbody > tr > td[data-label]::before {
|
||||
content: attr(data-label);
|
||||
font-weight: var(--font-weight-semibold);
|
||||
color: var(--color-text-muted);
|
||||
text-align: left;
|
||||
margin-right: auto;
|
||||
}
|
||||
}
|
||||
|
||||
.table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
|
||||
40
core/http/react-ui/src/components/ResponsiveTable.jsx
Normal file
40
core/http/react-ui/src/components/ResponsiveTable.jsx
Normal file
@@ -0,0 +1,40 @@
|
||||
import { useRef, useEffect } from 'react'
|
||||
|
||||
// Wraps a standard .table and makes it reflow into stacked label/value cards on
|
||||
// narrow screens. Column labels are derived from the <thead> and mirrored onto
|
||||
// each body cell via data-label (read by CSS ::before in the mobile layout), so
|
||||
// any table becomes responsive without hand-labelling every <td>.
|
||||
export default function ResponsiveTable({ children, className = '', style, containerStyle }) {
|
||||
const ref = useRef(null)
|
||||
|
||||
useEffect(() => {
|
||||
const table = ref.current
|
||||
if (!table) return
|
||||
const apply = () => {
|
||||
// Direct children only, so a nested table inside a cell is left alone.
|
||||
const heads = [...table.querySelectorAll(':scope > thead > tr > th')].map(th => th.textContent.trim())
|
||||
table.querySelectorAll(':scope > tbody > tr').forEach(tr => {
|
||||
const cells = [...tr.children]
|
||||
// Skip detail/expansion rows (a single cell spanning the table).
|
||||
if (cells.length === 1 && cells[0].colSpan > 1) return
|
||||
cells.forEach((td, i) => {
|
||||
if (heads[i]) td.setAttribute('data-label', heads[i])
|
||||
})
|
||||
})
|
||||
}
|
||||
apply()
|
||||
// Re-apply when rows change (sort, paging, live data). setAttribute touches
|
||||
// attributes only, so a childList/subtree observer won't retrigger itself.
|
||||
const obs = new MutationObserver(apply)
|
||||
obs.observe(table, { childList: true, subtree: true })
|
||||
return () => obs.disconnect()
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className="table-container" style={containerStyle}>
|
||||
<table ref={ref} className={`table table--responsive ${className}`.trim()} style={style}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
36
core/http/react-ui/src/components/UnsavedChangesGuard.jsx
Normal file
36
core/http/react-ui/src/components/UnsavedChangesGuard.jsx
Normal file
@@ -0,0 +1,36 @@
|
||||
import { useEffect, useCallback } from 'react'
|
||||
import { useBlocker } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ConfirmDialog from './ConfirmDialog'
|
||||
|
||||
// Guards against losing unsaved work: blocks in-app route changes (via the
|
||||
// router's useBlocker) and warns on tab close/reload (beforeunload) whenever
|
||||
// `when` is true. Drop into any page that has a dirty-state signal.
|
||||
export default function UnsavedChangesGuard({ when }) {
|
||||
const { t } = useTranslation('common')
|
||||
const blocker = useBlocker(
|
||||
useCallback(
|
||||
({ currentLocation, nextLocation }) => when && currentLocation.pathname !== nextLocation.pathname,
|
||||
[when]
|
||||
)
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
if (!when) return
|
||||
const handler = (e) => { e.preventDefault(); e.returnValue = '' }
|
||||
window.addEventListener('beforeunload', handler)
|
||||
return () => window.removeEventListener('beforeunload', handler)
|
||||
}, [when])
|
||||
|
||||
return (
|
||||
<ConfirmDialog
|
||||
open={blocker.state === 'blocked'}
|
||||
title={t('unsaved.title')}
|
||||
message={t('unsaved.message')}
|
||||
confirmLabel={t('unsaved.leave')}
|
||||
danger
|
||||
onConfirm={() => blocker.proceed?.()}
|
||||
onCancel={() => blocker.reset?.()}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
import { useState, useEffect, useMemo } from 'react'
|
||||
import { useState, useEffect, useMemo, useRef } from 'react'
|
||||
import { useParams, useNavigate, useLocation, useOutletContext, useSearchParams } from 'react-router-dom'
|
||||
import { agentsApi, skillsApi } from '../utils/api'
|
||||
import SearchableModelSelect from '../components/SearchableModelSelect'
|
||||
import PageHeader from '../components/PageHeader'
|
||||
import UnsavedChangesGuard from '../components/UnsavedChangesGuard'
|
||||
import { CAP_CHAT, CAP_TRANSCRIPT, CAP_TTS } from '../utils/capabilities'
|
||||
import Toggle from '../components/Toggle'
|
||||
import SettingRow from '../components/SettingRow'
|
||||
@@ -296,6 +297,8 @@ export default function AgentCreate() {
|
||||
const [activeSection, setActiveSection] = useState('BasicInfo')
|
||||
const [meta, setMeta] = useState(null)
|
||||
const [form, setForm] = useState({})
|
||||
// Snapshot of the form as first loaded, for the unsaved-changes guard.
|
||||
const initialFormRef = useRef(null)
|
||||
const [connectors, setConnectors] = useState([])
|
||||
const [actions, setActions] = useState([])
|
||||
const [filters, setFilters] = useState([])
|
||||
@@ -374,6 +377,7 @@ export default function AgentCreate() {
|
||||
if (Array.isArray(sourceConfig.selected_skills)) setSelectedSkills(sourceConfig.selected_skills)
|
||||
}
|
||||
|
||||
initialFormRef.current = initialForm
|
||||
setForm(initialForm)
|
||||
} catch (err) {
|
||||
addToast(`Failed to load configuration: ${err.message}`, 'error')
|
||||
@@ -819,8 +823,12 @@ export default function AgentCreate() {
|
||||
)
|
||||
}
|
||||
|
||||
const dirty = initialFormRef.current != null &&
|
||||
JSON.stringify(form) !== JSON.stringify(initialFormRef.current)
|
||||
|
||||
return (
|
||||
<div className="page page--narrow">
|
||||
<UnsavedChangesGuard when={dirty && !saving} />
|
||||
<style>{`
|
||||
.agent-form-container {
|
||||
display: flex;
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback } from 'react'
|
||||
import { fineTuneApi } from '../utils/api'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import PageHeader from '../components/PageHeader'
|
||||
import UnsavedChangesGuard from '../components/UnsavedChangesGuard'
|
||||
|
||||
const TRAINING_METHODS = ['sft', 'dpo', 'grpo', 'rloo', 'reward', 'kto', 'orpo']
|
||||
const TRAINING_TYPES = ['lora', 'loha', 'lokr', 'full']
|
||||
@@ -705,6 +706,8 @@ export default function FineTune() {
|
||||
const [error, setError] = useState('')
|
||||
const [backends, setBackends] = useState([])
|
||||
const [exportCheckpoint, setExportCheckpoint] = useState(null)
|
||||
// Baseline of the assembled config for the unsaved-changes guard.
|
||||
const initialConfigRef = useRef(null)
|
||||
|
||||
// Form state
|
||||
const [model, setModel] = useState('')
|
||||
@@ -845,6 +848,8 @@ export default function FineTune() {
|
||||
const resp = await fineTuneApi.startJob(req)
|
||||
setShowForm(false)
|
||||
setResumeFromCheckpoint('')
|
||||
// Job submitted: rebaseline so leaving the page no longer warns.
|
||||
initialConfigRef.current = JSON.stringify(getFormConfig())
|
||||
await loadJobs()
|
||||
|
||||
const newJob = { ...req, id: resp.id, status: 'queued', created_at: new Date().toISOString() }
|
||||
@@ -1057,8 +1062,13 @@ export default function FineTune() {
|
||||
setExportCheckpoint(checkpoint)
|
||||
}
|
||||
|
||||
// Lazy-init the baseline on first render; dirty when the open form diverges.
|
||||
if (initialConfigRef.current === null) initialConfigRef.current = JSON.stringify(getFormConfig())
|
||||
const dirty = JSON.stringify(getFormConfig()) !== initialConfigRef.current
|
||||
|
||||
return (
|
||||
<div className="page page--wide">
|
||||
<UnsavedChangesGuard when={dirty && showForm && !loading} />
|
||||
<PageHeader
|
||||
title={<>Fine-Tuning <span className="badge badge-warning" style={{ fontSize: '0.45em', verticalAlign: 'middle' }}>Experimental</span></>}
|
||||
supporting="Create and manage fine-tuning jobs"
|
||||
|
||||
@@ -12,6 +12,7 @@ import ManageSummary from '../components/ManageSummary'
|
||||
import MetaBadgeRow from '../components/MetaBadgeRow'
|
||||
import ActionMenu from '../components/ActionMenu'
|
||||
import ResourceRow, { ChevronCell, IconCell, StopPropagationCell } from '../components/ResourceRow'
|
||||
import ResponsiveTable from '../components/ResponsiveTable'
|
||||
import { useModels } from '../hooks/useModels'
|
||||
import { useGalleryEnrichment } from '../hooks/useGalleryEnrichment'
|
||||
import { useOperations } from '../hooks/useOperations'
|
||||
@@ -560,8 +561,7 @@ export default function Manage() {
|
||||
<button className="btn btn-ghost btn-sm" onClick={() => { setModelsSearch(''); setModelsFilter('all') }}>Clear filters</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: 30 }}></th>
|
||||
@@ -686,8 +686,7 @@ export default function Manage() {
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
@@ -855,8 +854,7 @@ export default function Manage() {
|
||||
return (
|
||||
<>
|
||||
{filterBar}
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: 30 }}></th>
|
||||
@@ -987,8 +985,7 @@ export default function Manage() {
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
</>
|
||||
)
|
||||
})()}
|
||||
|
||||
@@ -12,6 +12,7 @@ import PageHeader from '../components/PageHeader'
|
||||
import ConfirmDialog from '../components/ConfirmDialog'
|
||||
import GalleryLoader from '../components/GalleryLoader'
|
||||
import Toggle from '../components/Toggle'
|
||||
import ResponsiveTable from '../components/ResponsiveTable'
|
||||
import React from 'react'
|
||||
|
||||
|
||||
@@ -389,9 +390,7 @@ export default function Models() {
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="table-container" style={{ background: 'var(--color-bg-secondary)', borderRadius: 'var(--radius-lg)', overflow: 'hidden' }}>
|
||||
<div style={{ overflowX: 'auto' }}>
|
||||
<table className="table" style={{ minWidth: '800px' }}>
|
||||
<ResponsiveTable containerStyle={{ background: 'var(--color-bg-secondary)', borderRadius: 'var(--radius-lg)', overflow: 'hidden' }} style={{ minWidth: '800px' }}>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: '30px' }}></th>
|
||||
@@ -575,9 +574,7 @@ export default function Models() {
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
|
||||
{/* Pagination */}
|
||||
|
||||
@@ -9,6 +9,7 @@ import ActionMenu from '../components/ActionMenu'
|
||||
import SearchableModelSelect from '../components/SearchableModelSelect'
|
||||
import ImageSelector, { useImageSelector, dockerImage, dockerFlags } from '../components/ImageSelector'
|
||||
import StatCard from '../components/StatCard'
|
||||
import ResponsiveTable from '../components/ResponsiveTable'
|
||||
|
||||
function timeAgo(dateString) {
|
||||
if (!dateString) return 'never'
|
||||
@@ -1086,8 +1087,7 @@ export default function Nodes() {
|
||||
|
||||
{/* Node table */}
|
||||
{filteredNodes.length > 0 && (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
@@ -1533,8 +1533,7 @@ export default function Nodes() {
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
</>}
|
||||
|
||||
@@ -1560,8 +1559,7 @@ export default function Nodes() {
|
||||
No scheduling rules configured. Add a rule to control how models are placed on nodes.
|
||||
</p>
|
||||
) : schedulingConfigs.length > 0 && (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead><tr>
|
||||
<th>Model</th>
|
||||
<th>Mode</th>
|
||||
@@ -1667,8 +1665,7 @@ export default function Nodes() {
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { settingsApi, resourcesApi, brandingApi } from '../utils/api'
|
||||
import { useBranding } from '../contexts/BrandingContext'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import PageHeader from '../components/PageHeader'
|
||||
import UnsavedChangesGuard from '../components/UnsavedChangesGuard'
|
||||
import SearchableModelSelect from '../components/SearchableModelSelect'
|
||||
import { CAP_CHAT } from '../utils/capabilities'
|
||||
import Toggle from '../components/Toggle'
|
||||
@@ -159,6 +160,7 @@ export default function Settings() {
|
||||
|
||||
return (
|
||||
<div className="page page--medium" style={{ padding: 0 }}>
|
||||
<UnsavedChangesGuard when={isDirty} />
|
||||
{/* Header */}
|
||||
<div style={{ padding: 'var(--spacing-lg) var(--spacing-lg) 0' }}>
|
||||
<PageHeader
|
||||
|
||||
@@ -5,6 +5,7 @@ import { tracesApi, settingsApi } from '../utils/api'
|
||||
import { formatTimestamp } from '../utils/format'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import PageHeader from '../components/PageHeader'
|
||||
import ResponsiveTable from '../components/ResponsiveTable'
|
||||
import Toggle from '../components/Toggle'
|
||||
import SettingRow from '../components/SettingRow'
|
||||
import WaveformPlayer from '../components/audio/WaveformPlayer'
|
||||
@@ -317,7 +318,35 @@ export default function Traces() {
|
||||
const [backendCount, setBackendCount] = useState(0)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [expandedRow, setExpandedRow] = useState(null)
|
||||
const [sort, setSort] = useState({ key: null, dir: 'asc' })
|
||||
const [tracingEnabled, setTracingEnabled] = useState(null)
|
||||
|
||||
const TRACE_SORT = {
|
||||
method: (a, b) => (a.request?.method || '').localeCompare(b.request?.method || ''),
|
||||
path: (a, b) => (a.request?.path || '').localeCompare(b.request?.path || ''),
|
||||
status: (a, b) => (a.response?.status || 0) - (b.response?.status || 0),
|
||||
type: (a, b) => (a.type || '').localeCompare(b.type || ''),
|
||||
time: (a, b) => new Date(a.timestamp || 0) - new Date(b.timestamp || 0),
|
||||
model: (a, b) => (a.model_name || '').localeCompare(b.model_name || ''),
|
||||
duration: (a, b) => (a.duration || 0) - (b.duration || 0),
|
||||
}
|
||||
const toggleSort = (key) => {
|
||||
setExpandedRow(null)
|
||||
setSort(s => s.key === key ? { key, dir: s.dir === 'asc' ? 'desc' : 'asc' } : { key, dir: 'asc' })
|
||||
}
|
||||
const sortableTh = (key, label, props = {}) => (
|
||||
<th
|
||||
{...props}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
aria-sort={sort.key === key ? (sort.dir === 'asc' ? 'ascending' : 'descending') : 'none'}
|
||||
onClick={() => toggleSort(key)}
|
||||
onKeyDown={(e) => { if (e.key === 'Enter' || e.key === ' ') { e.preventDefault(); toggleSort(key) } }}
|
||||
style={{ cursor: 'pointer', userSelect: 'none', ...(props.style || {}) }}
|
||||
>
|
||||
{label}{sort.key === key && <i className={`fas fa-caret-${sort.dir === 'asc' ? 'up' : 'down'}`} style={{ marginLeft: 4, opacity: 0.7 }} aria-hidden="true" />}
|
||||
</th>
|
||||
)
|
||||
const [backendLoggingEnabled, setBackendLoggingEnabled] = useState(null)
|
||||
const [settings, setSettings] = useState(null)
|
||||
const [settingsExpanded, setSettingsExpanded] = useState(false)
|
||||
@@ -406,6 +435,13 @@ export default function Traces() {
|
||||
URL.revokeObjectURL(url)
|
||||
}
|
||||
|
||||
// Reset sort + expansion when switching trace tabs (columns differ).
|
||||
useEffect(() => { setSort({ key: null, dir: 'asc' }); setExpandedRow(null) }, [activeTab])
|
||||
|
||||
const sortedTraces = sort.key && TRACE_SORT[sort.key]
|
||||
? [...traces].sort((a, b) => sort.dir === 'asc' ? TRACE_SORT[sort.key](a, b) : TRACE_SORT[sort.key](b, a))
|
||||
: traces
|
||||
|
||||
return (
|
||||
<div className="page page--wide">
|
||||
<PageHeader title={t('traces.title')} supporting={t('traces.subtitle')} />
|
||||
@@ -537,19 +573,18 @@ export default function Traces() {
|
||||
</p>
|
||||
</div>
|
||||
) : activeTab === 'api' ? (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: '30px' }}></th>
|
||||
<th>Method</th>
|
||||
<th>Path</th>
|
||||
<th>Status</th>
|
||||
{sortableTh('method', 'Method')}
|
||||
{sortableTh('path', 'Path')}
|
||||
{sortableTh('status', 'Status')}
|
||||
<th style={{ width: '40px' }}>Result</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{traces.map((trace, i) => (
|
||||
{sortedTraces.map((trace, i) => (
|
||||
<React.Fragment key={i}>
|
||||
<tr onClick={() => setExpandedRow(expandedRow === i ? null : i)} style={{ cursor: 'pointer' }}>
|
||||
<td><i className={`fas fa-chevron-${expandedRow === i ? 'down' : 'right'}`} style={{ fontSize: '0.7rem' }} /></td>
|
||||
@@ -572,24 +607,22 @@ export default function Traces() {
|
||||
</React.Fragment>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
) : (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: '30px' }}></th>
|
||||
<th>Type</th>
|
||||
<th>Time</th>
|
||||
<th>Model</th>
|
||||
{sortableTh('type', 'Type')}
|
||||
{sortableTh('time', 'Time')}
|
||||
{sortableTh('model', 'Model')}
|
||||
<th>Summary</th>
|
||||
<th>Duration</th>
|
||||
{sortableTh('duration', 'Duration')}
|
||||
<th style={{ width: '40px' }}>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{traces.map((trace, i) => (
|
||||
{sortedTraces.map((trace, i) => (
|
||||
<React.Fragment key={i}>
|
||||
<tr onClick={() => setExpandedRow(expandedRow === i ? null : i)} style={{ cursor: 'pointer' }}>
|
||||
<td><i className={`fas fa-chevron-${expandedRow === i ? 'down' : 'right'}`} style={{ fontSize: '0.7rem' }} /></td>
|
||||
@@ -616,8 +649,7 @@ export default function Traces() {
|
||||
</React.Fragment>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -413,7 +413,7 @@ function UsageTimeChart({ data, predictedData, period }) {
|
||||
<span style={{ fontSize: '0.875rem', fontWeight: 600, color: 'var(--color-text-primary)' }}>Tokens over time</span>
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-md)', fontSize: '0.6875rem', color: 'var(--color-text-muted)' }}>
|
||||
<span><span style={{ display: 'inline-block', width: 8, height: 8, borderRadius: "var(--radius-sm)", background: 'var(--color-primary)', marginRight: 4, verticalAlign: 'middle' }} />Prompt</span>
|
||||
<span><span style={{ display: 'inline-block', width: 8, height: 8, borderRadius: "var(--radius-sm)", background: 'var(--color-primary)', opacity: 0.35, marginRight: 4, verticalAlign: 'middle' }} />Completion</span>
|
||||
<span><span style={{ display: 'inline-block', width: 8, height: 8, borderRadius: "var(--radius-sm)", background: 'var(--color-data-3)', marginRight: 4, verticalAlign: 'middle' }} />Completion</span>
|
||||
{predictedData && predictedData.length > 0 && (
|
||||
<span>
|
||||
<span style={{
|
||||
@@ -472,7 +472,7 @@ function UsageTimeChart({ data, predictedData, period }) {
|
||||
{/* Prompt tokens (bottom) */}
|
||||
<rect x={x} y={chartH - promptH - compH} width={barWidth} height={promptH} fill="var(--color-primary)" rx={2} />
|
||||
{/* Completion tokens (top) */}
|
||||
<rect x={x} y={chartH - compH} width={barWidth} height={compH} fill="var(--color-primary)" opacity={0.35} rx={2} />
|
||||
<rect x={x} y={chartH - compH} width={barWidth} height={compH} fill="var(--color-data-3)" rx={2} />
|
||||
</g>
|
||||
)
|
||||
})}
|
||||
@@ -571,7 +571,7 @@ function UsageTimeChart({ data, predictedData, period }) {
|
||||
{formatBucket(tooltip.data.bucket, period)}
|
||||
</div>
|
||||
<div><span style={{ color: 'var(--color-primary)' }}>Prompt:</span> {tooltip.predicted ? '~' : ''}{tooltip.data.prompt_tokens.toLocaleString()}</div>
|
||||
<div><span style={{ color: 'var(--color-text-secondary)' }}>Completion:</span> {tooltip.predicted ? '~' : ''}{tooltip.data.completion_tokens.toLocaleString()}</div>
|
||||
<div><span style={{ color: 'var(--color-data-3)' }}>Completion:</span> {tooltip.predicted ? '~' : ''}{tooltip.data.completion_tokens.toLocaleString()}</div>
|
||||
<div style={{ color: 'var(--color-text-muted)', borderTop: '1px solid var(--color-border)', marginTop: 2, paddingTop: 2 }}>
|
||||
{tooltip.predicted ? '~' : ''}{tooltip.data.request_count} requests
|
||||
</div>
|
||||
@@ -596,7 +596,7 @@ function ModelDistChart({ rows }) {
|
||||
<span style={{ fontSize: '0.875rem', fontWeight: 600, color: 'var(--color-text-primary)' }}>Token distribution by model</span>
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-md)', fontSize: '0.6875rem', color: 'var(--color-text-muted)' }}>
|
||||
<span><span style={{ display: 'inline-block', width: 8, height: 8, borderRadius: "var(--radius-sm)", background: 'var(--color-primary)', marginRight: 4, verticalAlign: 'middle' }} />Prompt</span>
|
||||
<span><span style={{ display: 'inline-block', width: 8, height: 8, borderRadius: "var(--radius-sm)", background: 'var(--color-primary)', opacity: 0.35, marginRight: 4, verticalAlign: 'middle' }} />Completion</span>
|
||||
<span><span style={{ display: 'inline-block', width: 8, height: 8, borderRadius: "var(--radius-sm)", background: 'var(--color-data-3)', marginRight: 4, verticalAlign: 'middle' }} />Completion</span>
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: gap }}>
|
||||
@@ -613,7 +613,7 @@ function ModelDistChart({ rows }) {
|
||||
</div>
|
||||
<div style={{ flex: 1, height: barH, background: 'var(--color-bg-primary)', borderRadius: "var(--radius-sm)", overflow: 'hidden', display: 'flex' }}>
|
||||
<div style={{ width: `${promptPct}%`, height: '100%', background: 'var(--color-primary)', transition: 'width 0.3s ease' }} />
|
||||
<div style={{ width: `${compPct}%`, height: '100%', background: 'var(--color-primary)', opacity: 0.35, transition: 'width 0.3s ease' }} />
|
||||
<div style={{ width: `${compPct}%`, height: '100%', background: 'var(--color-data-3)', transition: 'width 0.3s ease' }} />
|
||||
</div>
|
||||
<div style={{
|
||||
minWidth: 60, textAlign: 'right', fontSize: '0.75rem', fontFamily: 'var(--font-mono)',
|
||||
|
||||
@@ -101,7 +101,8 @@ export default function SourceTimeChart({ buckets = [], selectedKey, totals }) {
|
||||
viewBox={`0 0 ${width} ${height}`}
|
||||
preserveAspectRatio="none"
|
||||
style={{ width: '100%', height: 160, display: 'block' }}
|
||||
aria-hidden
|
||||
role="img"
|
||||
aria-label={t('usage.sources.topSources')}
|
||||
>
|
||||
{series.map((row, i) => {
|
||||
let y = height
|
||||
|
||||
@@ -5,6 +5,7 @@ import { useAuth } from '../context/AuthContext'
|
||||
import { adminUsersApi, adminInvitesApi } from '../utils/api'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import PageHeader from '../components/PageHeader'
|
||||
import ResponsiveTable from '../components/ResponsiveTable'
|
||||
import Modal from '../components/Modal'
|
||||
import ConfirmDialog from '../components/ConfirmDialog'
|
||||
import Toggle from '../components/Toggle'
|
||||
@@ -570,8 +571,7 @@ function InvitesTab({ addToast }) {
|
||||
<p className="empty-state-text">Generate an invite link to let someone register.</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Invite Link</th>
|
||||
@@ -638,8 +638,7 @@ function InvitesTab({ addToast }) {
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
<ConfirmDialog
|
||||
open={!!confirmDialog}
|
||||
@@ -797,6 +796,33 @@ export default function Users() {
|
||||
return (u.name || '').toLowerCase().includes(q) || (u.email || '').toLowerCase().includes(q)
|
||||
})
|
||||
|
||||
const [sort, setSort] = useState({ key: null, dir: 'asc' })
|
||||
const USER_SORT = {
|
||||
name: (a, b) => (a.name || '').localeCompare(b.name || ''),
|
||||
email: (a, b) => (a.email || '').localeCompare(b.email || ''),
|
||||
provider: (a, b) => (a.provider || '').localeCompare(b.provider || ''),
|
||||
role: (a, b) => (a.role || '').localeCompare(b.role || ''),
|
||||
status: (a, b) => (a.status || '').localeCompare(b.status || ''),
|
||||
created: (a, b) => new Date(a.createdAt || 0) - new Date(b.createdAt || 0),
|
||||
}
|
||||
const sortedUsers = sort.key
|
||||
? [...filtered].sort((a, b) => sort.dir === 'asc' ? USER_SORT[sort.key](a, b) : USER_SORT[sort.key](b, a))
|
||||
: filtered
|
||||
const toggleSort = (key) => setSort(s => s.key === key ? { key, dir: s.dir === 'asc' ? 'desc' : 'asc' } : { key, dir: 'asc' })
|
||||
const sortableTh = (key, label, className) => (
|
||||
<th
|
||||
className={className}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
aria-sort={sort.key === key ? (sort.dir === 'asc' ? 'ascending' : 'descending') : 'none'}
|
||||
onClick={() => toggleSort(key)}
|
||||
onKeyDown={(e) => { if (e.key === 'Enter' || e.key === ' ') { e.preventDefault(); toggleSort(key) } }}
|
||||
style={{ cursor: 'pointer', userSelect: 'none' }}
|
||||
>
|
||||
{label}{sort.key === key && <i className={`fas fa-caret-${sort.dir === 'asc' ? 'up' : 'down'}`} style={{ marginLeft: 4, opacity: 0.7 }} aria-hidden="true" />}
|
||||
</th>
|
||||
)
|
||||
|
||||
const handlePermissionSave = (userId, newPerms, newModels, newQuotas) => {
|
||||
setUsers(prev => prev.map(u => u.id === userId ? { ...u, permissions: newPerms, allowed_models: newModels, quotas: newQuotas } : u))
|
||||
}
|
||||
@@ -854,22 +880,21 @@ export default function Users() {
|
||||
<p className="empty-state-text">{search ? 'Try a different search term.' : 'No registered users found.'}</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<ResponsiveTable>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>User</th>
|
||||
<th>Email</th>
|
||||
<th>Provider</th>
|
||||
<th>Role</th>
|
||||
{sortableTh('name', 'User')}
|
||||
{sortableTh('email', 'Email')}
|
||||
{sortableTh('provider', 'Provider')}
|
||||
{sortableTh('role', 'Role')}
|
||||
<th>Permissions</th>
|
||||
<th>Status</th>
|
||||
<th>Created</th>
|
||||
{sortableTh('status', 'Status')}
|
||||
{sortableTh('created', 'Created')}
|
||||
<th className="cell-actions">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filtered.map(u => (
|
||||
{sortedUsers.map(u => (
|
||||
<tr key={u.id}>
|
||||
<td>
|
||||
<div className="user-identity">
|
||||
@@ -943,8 +968,7 @@ export default function Users() {
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</ResponsiveTable>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/distributed"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -402,6 +403,16 @@ func (g *GalleryService) applyCancel(id string) {
|
||||
}
|
||||
}
|
||||
|
||||
// newUserCancellableContext returns a child context whose CancelFunc cancels
|
||||
// with the downloader.ErrUserCancelled cause. This lets the download layer
|
||||
// distinguish a deliberate user cancel (discard the half-downloaded .partial)
|
||||
// from an incidental cancellation such as process shutdown (keep the .partial
|
||||
// so the next run resumes via Range instead of restarting from zero).
|
||||
func newUserCancellableContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
ctx, cancelCause := context.WithCancelCause(parent)
|
||||
return ctx, func() { cancelCause(downloader.ErrUserCancelled) }
|
||||
}
|
||||
|
||||
// storeCancellation stores a cancellation function for an operation
|
||||
func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) {
|
||||
g.Lock()
|
||||
@@ -444,7 +455,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader,
|
||||
case op := <-g.BackendGalleryChannel:
|
||||
// Create context if not provided
|
||||
if op.Context == nil {
|
||||
op.Context, op.CancelFunc = context.WithCancel(c)
|
||||
op.Context, op.CancelFunc = newUserCancellableContext(c)
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
} else if op.CancelFunc != nil {
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
@@ -472,7 +483,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader,
|
||||
case op := <-g.ModelGalleryChannel:
|
||||
// Create context if not provided
|
||||
if op.Context == nil {
|
||||
op.Context, op.CancelFunc = context.WithCancel(c)
|
||||
op.Context, op.CancelFunc = newUserCancellableContext(c)
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
} else if op.CancelFunc != nil {
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
|
||||
@@ -36,6 +36,11 @@ type BackendNode struct {
|
||||
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
|
||||
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
|
||||
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
|
||||
// GPUComputeCapability is the worker GPU's compute capability as
|
||||
// "major.minor" (e.g. "12.1" for GB10 / DGX Spark). Reported by the worker
|
||||
// on registration; used by the router to pick per-arch options (e.g. a
|
||||
// larger physical batch on Blackwell). Empty when unknown / non-NVIDIA.
|
||||
GPUComputeCapability string `gorm:"column:gpu_compute_capability;size:16" json:"gpu_compute_capability"`
|
||||
// MaxReplicasPerModel caps how many replicas of any one model can run on
|
||||
// this node concurrently. Default 1 preserves the historical "one
|
||||
// (node, model)" assumption; set higher (via worker --max-replicas-per-model)
|
||||
@@ -69,6 +74,7 @@ const (
|
||||
ColReservedVRAM = "reserved_vram"
|
||||
ColAvailableRAM = "available_ram"
|
||||
ColGPUVendor = "gpu_vendor"
|
||||
ColGPUComputeCap = "gpu_compute_capability"
|
||||
ColLastHeartbeat = "last_heartbeat"
|
||||
ColMaxReplicasPerModel = "max_replicas_per_model"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
@@ -138,6 +139,30 @@ type scheduleLoadResult struct {
|
||||
ReplicaIndex int
|
||||
}
|
||||
|
||||
// applyNodeHardwareDefaults tunes node-agnostic ModelOptions to the GPU of the
|
||||
// node that was actually selected to run the model, reusing the same hardware
|
||||
// heuristics as single-host config loading (core/config). On Blackwell it
|
||||
// raises the physical batch; on non-Blackwell it resets a hardware-default that
|
||||
// an upstream host (the GPU-less frontend in distributed mode) guessed higher.
|
||||
// Only values the heuristics themselves manage are touched, so an explicit user
|
||||
// batch (e.g. 1024) is never overridden.
|
||||
func applyNodeHardwareDefaults(opts *pb.ModelOptions, node *BackendNode) {
|
||||
if opts == nil || node == nil {
|
||||
return
|
||||
}
|
||||
gpu := config.GPU{
|
||||
Vendor: node.GPUVendor,
|
||||
ComputeCapability: node.GPUComputeCapability,
|
||||
VRAM: node.TotalVRAM,
|
||||
}
|
||||
if config.IsManagedPhysicalBatch(int(opts.NBatch)) {
|
||||
opts.NBatch = int32(config.PhysicalBatch(gpu))
|
||||
}
|
||||
// Default concurrent serving for the selected node (the frontend that built
|
||||
// the options may have no GPU). Only adds when no parallel option is set.
|
||||
opts.Options = config.EnsureParallelOption(opts.Options, gpu)
|
||||
}
|
||||
|
||||
// scheduleAndLoad is the shared core for loading a model on a new node.
|
||||
// Used by both Route() (for first-time loads) and ScheduleAndLoadModel() (for reconciler scale-ups).
|
||||
//
|
||||
@@ -153,6 +178,11 @@ func (r *SmartRouter) scheduleAndLoad(ctx context.Context, backendType, tracking
|
||||
return nil, fmt.Errorf("no available nodes: %w", err)
|
||||
}
|
||||
|
||||
// Tune node-agnostic options to the SELECTED node's GPU. Only now do we know
|
||||
// which node (and its compute capability) will run the model — the frontend
|
||||
// that built modelOpts may have no GPU at all in distributed mode.
|
||||
applyNodeHardwareDefaults(modelOpts, node)
|
||||
|
||||
// Pre-stage model files via FileStager before loading
|
||||
loadOpts := modelOpts
|
||||
if r.fileStager != nil && modelOpts != nil {
|
||||
|
||||
46
core/services/nodes/router_hardware_internal_test.go
Normal file
46
core/services/nodes/router_hardware_internal_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("applyNodeHardwareDefaults", func() {
|
||||
It("raises a managed default batch on a Blackwell node", func() {
|
||||
opts := &pb.ModelOptions{NBatch: config.DefaultPhysicalBatch}
|
||||
applyNodeHardwareDefaults(opts, &BackendNode{GPUComputeCapability: "12.1"})
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(config.BlackwellPhysicalBatch))
|
||||
})
|
||||
|
||||
It("resets a Blackwell guess on a non-Blackwell node", func() {
|
||||
// frontend (Blackwell) guessed high, but the selected node is not Blackwell
|
||||
opts := &pb.ModelOptions{NBatch: config.BlackwellPhysicalBatch}
|
||||
applyNodeHardwareDefaults(opts, &BackendNode{GPUComputeCapability: "9.0"})
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(config.DefaultPhysicalBatch))
|
||||
})
|
||||
|
||||
It("never overrides an explicit (non-managed) batch", func() {
|
||||
opts := &pb.ModelOptions{NBatch: 1024}
|
||||
applyNodeHardwareDefaults(opts, &BackendNode{GPUComputeCapability: "12.1"})
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(int32(1024)))
|
||||
})
|
||||
|
||||
It("adds a VRAM-scaled parallel option for the selected node", func() {
|
||||
// frontend may have had no GPU (no parallel option); the node has a big GPU
|
||||
opts := &pb.ModelOptions{NBatch: config.DefaultPhysicalBatch}
|
||||
applyNodeHardwareDefaults(opts, &BackendNode{GPUComputeCapability: "12.1", TotalVRAM: 119 << 30})
|
||||
Expect(opts.Options).To(ContainElement("parallel:8"))
|
||||
})
|
||||
|
||||
It("never overrides an explicit parallel option on the node path", func() {
|
||||
opts := &pb.ModelOptions{NBatch: config.DefaultPhysicalBatch, Options: []string{"parallel:2"}}
|
||||
applyNodeHardwareDefaults(opts, &BackendNode{GPUComputeCapability: "12.1", TotalVRAM: 119 << 30})
|
||||
Expect(opts.Options).To(Equal([]string{"parallel:2"}))
|
||||
})
|
||||
|
||||
It("no-ops on nil inputs", func() {
|
||||
Expect(func() { applyNodeHardwareDefaults(nil, nil) }).ToNot(Panic())
|
||||
})
|
||||
})
|
||||
@@ -73,6 +73,10 @@ func (cfg *Config) registrationBody() map[string]any {
|
||||
// Detect GPU info for VRAM-aware scheduling
|
||||
totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
|
||||
gpuVendor, _ := xsysinfo.DetectGPUVendor()
|
||||
// Compute capability (e.g. "12.1" for GB10) lets the router pick per-arch
|
||||
// options (e.g. larger physical batch on Blackwell). Detected on the worker
|
||||
// because only the worker sees the GPU in distributed mode.
|
||||
gpuComputeCap := xsysinfo.NVIDIAComputeCapability()
|
||||
|
||||
maxReplicas := cfg.MaxReplicasPerModel
|
||||
if maxReplicas < 1 {
|
||||
@@ -85,6 +89,7 @@ func (cfg *Config) registrationBody() map[string]any {
|
||||
"total_vram": totalVRAM,
|
||||
"available_vram": totalVRAM, // initially all VRAM is available
|
||||
"gpu_vendor": gpuVendor,
|
||||
"gpu_compute_capability": gpuComputeCap,
|
||||
"max_replicas_per_model": maxReplicas,
|
||||
}
|
||||
|
||||
|
||||
13
flake.lock
generated
13
flake.lock
generated
@@ -1,17 +1,5 @@
|
||||
{
|
||||
"nodes": {
|
||||
"inference-defaults": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"narHash": "sha256-ygWIkY2xiUEWqAZQM4/0vBz8vWd/RKX5VBj7EHovU14=",
|
||||
"type": "file",
|
||||
"url": "https://raw.githubusercontent.com/unslothai/unsloth/main/studio/backend/assets/configs/inference_defaults.json"
|
||||
},
|
||||
"original": {
|
||||
"type": "file",
|
||||
"url": "https://raw.githubusercontent.com/unslothai/unsloth/main/studio/backend/assets/configs/inference_defaults.json"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1777578337,
|
||||
@@ -30,7 +18,6 @@
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"inference-defaults": "inference-defaults",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
}
|
||||
|
||||
48
flake.nix
48
flake.nix
@@ -4,24 +4,36 @@
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
inference-defaults = {
|
||||
url = "https://raw.githubusercontent.com/unslothai/unsloth/main/studio/backend/assets/configs/inference_defaults.json";
|
||||
flake = false;
|
||||
};
|
||||
};
|
||||
|
||||
outputs = { self, nixpkgs, inference-defaults }:
|
||||
outputs = { self, nixpkgs }:
|
||||
let
|
||||
system = "x86_64-linux";
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
in {
|
||||
packages.${system}.default = pkgs.buildGoModule {
|
||||
reactUi = pkgs.buildNpmPackage {
|
||||
pname = "localai-react-ui";
|
||||
version = "custom";
|
||||
src = ./core/http/react-ui;
|
||||
npmDeps = pkgs.importNpmLock {
|
||||
npmRoot = ./core/http/react-ui;
|
||||
};
|
||||
npmConfigHook = pkgs.importNpmLock.npmConfigHook;
|
||||
npmBuildScript = "build";
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
mkdir -p $out
|
||||
cp -r dist $out/
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
localai-unwrapped = pkgs.buildGoModule {
|
||||
pname = "localai";
|
||||
version = "custom";
|
||||
|
||||
src = ./.;
|
||||
proxyVendor = true;
|
||||
vendorHash = "sha256-6f3adjGsoFXlUtXjBDHP4Mv9jKCOK3aeUXprm0EAVO8=";
|
||||
vendorHash = "sha256-z3lxQS8mXFuJzvYamejwapwVEmLpeAoiO3ksUKb4I3Q=";
|
||||
|
||||
nativeBuildInputs = with pkgs; [
|
||||
pkg-config cmake gcc protobuf go-protobuf protoc-gen-go protoc-gen-go-grpc
|
||||
@@ -44,8 +56,9 @@
|
||||
|
||||
go mod edit -replace github.com/mudler/LocalAI/pkg/grpc/proto=./pkg/grpc/proto
|
||||
|
||||
mkdir -p core/config/gen_inference_defaults
|
||||
cp ${inference-defaults} core/config/gen_inference_defaults/inference_defaults.json
|
||||
mkdir -p core/http/react-ui
|
||||
cp -r ${reactUi}/dist core/http/react-ui/dist
|
||||
|
||||
sed -i '/go:generate/d' core/config/inference_defaults.go || true
|
||||
|
||||
'';
|
||||
@@ -57,6 +70,21 @@
|
||||
[ -f $out/bin/local-ai ] && mv $out/bin/local-ai $out/bin/localai
|
||||
'';
|
||||
};
|
||||
in {
|
||||
packages.${system} = {
|
||||
localai-unwrapped = localai-unwrapped;
|
||||
|
||||
default = pkgs.buildFHSEnv {
|
||||
name = "localai";
|
||||
targetPkgs = pkgs: with pkgs; [
|
||||
localai-unwrapped
|
||||
bash
|
||||
coreutils
|
||||
gnugrep
|
||||
];
|
||||
runScript = "${localai-unwrapped}/bin/localai";
|
||||
};
|
||||
};
|
||||
|
||||
devShells.${system}.default = pkgs.mkShell {
|
||||
packages = with pkgs; [
|
||||
|
||||
@@ -34784,10 +34784,10 @@
|
||||
files:
|
||||
- filename: chatterbox-t3-q8_0.gguf
|
||||
uri: huggingface://cstr/chatterbox-GGUF/chatterbox-t3-q8_0.gguf
|
||||
sha256: 7b2da930c27df7e43d17a077bb58433b1bc33474ad66d781f715a7125f65d075
|
||||
sha256: d87e51da512ad54af66e587e5d5cf83762c407198cc284f825ea220062b9a67e
|
||||
- filename: chatterbox-s3gen-q8_0.gguf
|
||||
uri: huggingface://cstr/chatterbox-GGUF/chatterbox-s3gen-q8_0.gguf
|
||||
sha256: 6bbb93b892deeea73330cf773218e776e4bd0cf6ba71f60ef4dba72c922d0b3b
|
||||
sha256: 329cd9e3bdb273deae2a81754caeb67d3fe84e41bd308e98e35532645c9c4920
|
||||
- name: qwen3-tts-customvoice-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
|
||||
148
pkg/downloader/cancel_test.go
Normal file
148
pkg/downloader/cancel_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package downloader_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/downloader"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Download cancellation", func() {
|
||||
var filePath string
|
||||
|
||||
// streamingRangeServer serves data one small chunk at a time with a short
|
||||
// pause between chunks, so a context cancellation can land mid-transfer.
|
||||
// It honors a `bytes=N-` Range request so a second attempt can resume.
|
||||
streamingRangeServer := func(data []byte) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
start := 0
|
||||
if rh := r.Header.Get("Range"); rh != "" {
|
||||
_, _ = fmt.Sscanf(strings.TrimPrefix(rh, "bytes="), "%d-", &start)
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)-start))
|
||||
if start > 0 {
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
f, _ := w.(http.Flusher)
|
||||
for i := start; i < len(data); i += 256 {
|
||||
end := i + 256
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
}
|
||||
if _, err := w.Write(data[i:end]); err != nil {
|
||||
return
|
||||
}
|
||||
if f != nil {
|
||||
f.Flush()
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
dir, err := os.Getwd()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
filePath = dir + "/cancel_model"
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.Remove(filePath)
|
||||
_ = os.Remove(filePath + ".partial")
|
||||
})
|
||||
|
||||
It("keeps the .partial file when the context is cancelled so the download can resume", func() {
|
||||
data := make([]byte, 8192)
|
||||
_, err := rand.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := streamingRangeServer(data)
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, context.Canceled)).To(BeTrue())
|
||||
|
||||
info, statErr := os.Stat(filePath + ".partial")
|
||||
Expect(statErr).ToNot(HaveOccurred(),
|
||||
"a cancelled download must leave its .partial behind so the retry resumes instead of restarting from zero")
|
||||
Expect(info.Size()).To(BeNumerically(">", 0))
|
||||
Expect(info.Size()).To(BeNumerically("<", int64(len(data))))
|
||||
})
|
||||
|
||||
It("discards the .partial when the cancellation cause is ErrUserCancelled", func() {
|
||||
data := make([]byte, 8192)
|
||||
_, err := rand.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := streamingRangeServer(data)
|
||||
defer server.Close()
|
||||
|
||||
// A deliberate user abort: cancel WITH the ErrUserCancelled cause. The
|
||||
// half-finished download should not linger on disk.
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel(ErrUserCancelled)
|
||||
}()
|
||||
|
||||
err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, context.Canceled)).To(BeTrue())
|
||||
|
||||
Expect(filePath + ".partial").ToNot(BeAnExistingFile(),
|
||||
"a deliberate user cancel must not leave a dangling .partial behind")
|
||||
})
|
||||
|
||||
It("resumes from the preserved .partial after a cancellation and completes", func() {
|
||||
data := make([]byte, 8192)
|
||||
_, err := rand.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sum := sha256.Sum256(data)
|
||||
sha := fmt.Sprintf("%x", sum)
|
||||
server := streamingRangeServer(data)
|
||||
defer server.Close()
|
||||
|
||||
// First attempt: cancel mid-stream.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
err = URI(server.URL).DownloadFileWithContext(ctx, filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
partialInfo, statErr := os.Stat(filePath + ".partial")
|
||||
Expect(statErr).ToNot(HaveOccurred())
|
||||
resumedFrom := partialInfo.Size()
|
||||
Expect(resumedFrom).To(BeNumerically(">", 0))
|
||||
|
||||
// Second attempt: fresh context, must resume and finish with a valid SHA.
|
||||
err = URI(server.URL).DownloadFileWithContext(context.Background(), filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
final, rerr := os.ReadFile(filePath)
|
||||
Expect(rerr).ToNot(HaveOccurred())
|
||||
Expect(final).To(Equal(data))
|
||||
})
|
||||
})
|
||||
69
pkg/downloader/partial.go
Normal file
69
pkg/downloader/partial.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// PartialFileSuffix marks an in-progress download. The success path renames the
|
||||
// partial to its final name, so any leftover with this suffix is an unfinished
|
||||
// transfer.
|
||||
const PartialFileSuffix = ".partial"
|
||||
|
||||
// CleanupStalePartialFiles removes *.partial files under root whose last
|
||||
// modification is older than olderThan, returning the number removed. These are
|
||||
// abandoned downloads left by a process killed mid-transfer (OOM, restart) or
|
||||
// by a stall whose cleanup never ran; without reaping they accumulate and can
|
||||
// fill the models volume. A still-in-progress download touches its .partial on
|
||||
// every write, so a generous olderThan never trims an active transfer.
|
||||
//
|
||||
// A missing root is not an error (nothing to clean). Unreadable entries are
|
||||
// skipped so one bad file does not abort the whole sweep.
|
||||
func CleanupStalePartialFiles(root string, olderThan time.Duration) (int, error) {
|
||||
if _, err := os.Stat(root); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
cutoff := time.Now().Add(-olderThan)
|
||||
|
||||
// Collect candidates during the walk and delete them afterwards rather than
|
||||
// mutating the tree from inside the WalkDir callback (avoids the symlink
|
||||
// TOCTOU class flagged by gosec G122, and never removes an entry mid-walk).
|
||||
var stale []string
|
||||
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil // skip unreadable subtree, keep going
|
||||
}
|
||||
if d.IsDir() || !strings.HasSuffix(d.Name(), PartialFileSuffix) {
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil || info.ModTime().After(cutoff) {
|
||||
return nil
|
||||
}
|
||||
stale = append(stale, path)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
removed := 0
|
||||
for _, path := range stale {
|
||||
if err := os.Remove(path); err != nil {
|
||||
xlog.Warn("failed to remove stale partial download", "file", path, "error", err)
|
||||
continue
|
||||
}
|
||||
removed++
|
||||
xlog.Info("removed stale partial download", "file", path)
|
||||
}
|
||||
return removed, nil
|
||||
}
|
||||
53
pkg/downloader/partial_test.go
Normal file
53
pkg/downloader/partial_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package downloader_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/downloader"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("CleanupStalePartialFiles", func() {
|
||||
var root string
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
root, err = os.MkdirTemp("", "partials")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(root)
|
||||
})
|
||||
|
||||
It("removes stale .partial files (recursively) while keeping fresh ones and completed files", func() {
|
||||
nested := filepath.Join(root, "llama-cpp", "models", "foo")
|
||||
Expect(os.MkdirAll(nested, 0755)).To(Succeed())
|
||||
|
||||
stale := filepath.Join(nested, "model.gguf.partial")
|
||||
fresh := filepath.Join(root, "fresh.gguf.partial")
|
||||
completed := filepath.Join(root, "done.gguf")
|
||||
for _, f := range []string{stale, fresh, completed} {
|
||||
Expect(os.WriteFile(f, []byte("data"), 0644)).To(Succeed())
|
||||
}
|
||||
old := time.Now().Add(-2 * time.Hour)
|
||||
Expect(os.Chtimes(stale, old, old)).To(Succeed())
|
||||
|
||||
removed, err := CleanupStalePartialFiles(root, time.Hour)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(1))
|
||||
|
||||
Expect(stale).ToNot(BeAnExistingFile())
|
||||
Expect(fresh).To(BeAnExistingFile())
|
||||
Expect(completed).To(BeAnExistingFile())
|
||||
})
|
||||
|
||||
It("returns no error when the root directory does not exist", func() {
|
||||
removed, err := CleanupStalePartialFiles(filepath.Join(root, "does-not-exist"), time.Hour)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(0))
|
||||
})
|
||||
})
|
||||
77
pkg/downloader/stall.go
Normal file
77
pkg/downloader/stall.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DownloadStallTimeout bounds how long an in-flight download may receive no
|
||||
// data before it is aborted. A silently-dropped TCP connection (no FIN/RST)
|
||||
// would otherwise block the body read forever, freezing an install at N bytes
|
||||
// until an external reaper kills it. Overridable (tests set it small); a value
|
||||
// <= 0 disables the guard.
|
||||
var DownloadStallTimeout = 60 * time.Second
|
||||
|
||||
// idleTimeoutReader wraps a streaming ReadCloser and aborts reads that make no
|
||||
// progress within timeout. A standard io.Copy blocks indefinitely on a Read
|
||||
// against a dead-but-unclosed socket; nothing in the copy loop can interrupt a
|
||||
// blocked syscall. The watchdog timer closes the underlying reader on expiry,
|
||||
// which unblocks the in-flight Read with an error. Each read that returns data
|
||||
// resets the idle clock, so a slow-but-steady transfer never trips the guard.
|
||||
type idleTimeoutReader struct {
|
||||
rc io.ReadCloser
|
||||
timeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
timer *time.Timer
|
||||
fired bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func newIdleTimeoutReader(rc io.ReadCloser, timeout time.Duration) *idleTimeoutReader {
|
||||
r := &idleTimeoutReader{rc: rc, timeout: timeout}
|
||||
r.timer = time.AfterFunc(timeout, r.onStall)
|
||||
return r
|
||||
}
|
||||
|
||||
// onStall fires when no data has arrived within the timeout. Closing the
|
||||
// underlying reader is what unblocks a Read parked in the kernel.
|
||||
func (r *idleTimeoutReader) onStall() {
|
||||
r.mu.Lock()
|
||||
if r.done {
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
r.fired = true
|
||||
r.mu.Unlock()
|
||||
_ = r.rc.Close()
|
||||
}
|
||||
|
||||
func (r *idleTimeoutReader) Read(p []byte) (int, error) {
|
||||
n, err := r.rc.Read(p)
|
||||
if n > 0 {
|
||||
r.timer.Reset(r.timeout)
|
||||
}
|
||||
if err != nil {
|
||||
r.mu.Lock()
|
||||
fired := r.fired
|
||||
r.mu.Unlock()
|
||||
if fired {
|
||||
// Translate the "use of closed connection" the watchdog induced
|
||||
// into an actionable stall error. This is not context.Canceled,
|
||||
// so the caller keeps the .partial file for a later resume.
|
||||
return n, fmt.Errorf("download stalled: no data received for %s", r.timeout)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *idleTimeoutReader) Close() error {
|
||||
r.mu.Lock()
|
||||
r.done = true
|
||||
r.mu.Unlock()
|
||||
r.timer.Stop()
|
||||
return r.rc.Close()
|
||||
}
|
||||
131
pkg/downloader/stall_test.go
Normal file
131
pkg/downloader/stall_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package downloader_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/downloader"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Download stall timeout", func() {
|
||||
var filePath string
|
||||
var savedTimeout time.Duration
|
||||
|
||||
BeforeEach(func() {
|
||||
dir, err := os.Getwd()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
filePath = dir + "/stall_model"
|
||||
savedTimeout = DownloadStallTimeout
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
DownloadStallTimeout = savedTimeout
|
||||
_ = os.Remove(filePath)
|
||||
_ = os.Remove(filePath + ".partial")
|
||||
})
|
||||
|
||||
It("aborts a download that stalls mid-stream instead of hanging forever", func() {
|
||||
// Server sends a chunk, flushes, then blocks forever without closing
|
||||
// the connection — a silently-dropped TCP stream. Without a stall
|
||||
// guard the body Read blocks indefinitely and DownloadFile never
|
||||
// returns.
|
||||
release := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(make([]byte, 4096))
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
<-release // hang: no more data, never close
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(release)
|
||||
|
||||
DownloadStallTimeout = 300 * time.Millisecond
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- URI(server.URL).DownloadFileWithContext(
|
||||
context.Background(), filePath, "", 1, 1,
|
||||
func(s1, s2, s3 string, f float64) {})
|
||||
}()
|
||||
|
||||
var err error
|
||||
Eventually(done, "5s").Should(Receive(&err))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("stall"))
|
||||
})
|
||||
|
||||
It("preserves the .partial file when a download stalls so it can resume", func() {
|
||||
release := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(make([]byte, 4096))
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
<-release
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(release)
|
||||
|
||||
DownloadStallTimeout = 300 * time.Millisecond
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- URI(server.URL).DownloadFileWithContext(
|
||||
context.Background(), filePath, "", 1, 1,
|
||||
func(s1, s2, s3 string, f float64) {})
|
||||
}()
|
||||
Eventually(done, "5s").Should(Receive(HaveOccurred()))
|
||||
|
||||
info, statErr := os.Stat(filePath + ".partial")
|
||||
Expect(statErr).ToNot(HaveOccurred(), "the .partial must survive a stall so the next attempt can resume")
|
||||
Expect(info.Size()).To(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
It("does not abort a slow-but-steady download", func() {
|
||||
// One byte every 100ms keeps the idle clock from ever expiring even
|
||||
// though the total transfer outlasts the stall timeout.
|
||||
payload := make([]byte, 12)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
f, _ := w.(http.Flusher)
|
||||
for i := range payload {
|
||||
_, _ = w.Write(payload[i : i+1])
|
||||
if f != nil {
|
||||
f.Flush()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
DownloadStallTimeout = 300 * time.Millisecond
|
||||
|
||||
err := URI(server.URL).DownloadFileWithContext(
|
||||
context.Background(), filePath, "", 1, 1,
|
||||
func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
@@ -330,6 +330,18 @@ func (s URI) ResolveURL() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// ErrUserCancelled distinguishes a deliberate user abort from an incidental
|
||||
// context cancellation (process shutdown, pod restart). Pass it as the cause
|
||||
// when cancelling the download context:
|
||||
//
|
||||
// ctx, cancel := context.WithCancelCause(parent)
|
||||
// cancel(downloader.ErrUserCancelled) // discards the .partial
|
||||
//
|
||||
// On a deliberate cancel the downloader removes the .partial (the user does not
|
||||
// want a half-download lingering). On a plain cancellation it keeps the .partial
|
||||
// so the next run resumes via Range instead of restarting from zero.
|
||||
var ErrUserCancelled = errors.New("download cancelled by user")
|
||||
|
||||
func removePartialFile(tmpFilePath string) error {
|
||||
xlog.Debug("Removing temporary file", "file", tmpFilePath)
|
||||
if err := os.Remove(tmpFilePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -594,11 +606,17 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string
|
||||
// Start the request
|
||||
resp, err := downloadClient.Do(req)
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if errors.Is(err, context.Canceled) {
|
||||
// Clean up partial file on cancellation
|
||||
removePartialFile(tmpFilePath)
|
||||
return err
|
||||
// Detect cancellation via the context, not the returned error: a
|
||||
// request cancelled *with a cause* surfaces the cause error (not
|
||||
// context.Canceled) from the HTTP client. Keep the .partial for
|
||||
// resume on an incidental cancel (shutdown, restart) — large GGUFs
|
||||
// take long enough that deleting progress means they never finish —
|
||||
// but discard it on a deliberate user abort (ErrUserCancelled).
|
||||
if ctx.Err() != nil {
|
||||
if errors.Is(context.Cause(ctx), ErrUserCancelled) {
|
||||
_ = removePartialFile(tmpFilePath)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("failed to download file %q: %v", filePath, err)
|
||||
}
|
||||
@@ -608,6 +626,13 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string
|
||||
return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode)
|
||||
}
|
||||
source = resp.Body
|
||||
// Guard against a silently-stalled stream: a dropped TCP connection
|
||||
// that never sends FIN/RST would otherwise block the body Read (and
|
||||
// thus the whole install) forever. The watchdog aborts after a window
|
||||
// of zero progress; the .partial is kept for a later resume.
|
||||
if DownloadStallTimeout > 0 {
|
||||
source = newIdleTimeoutReader(resp.Body, DownloadStallTimeout)
|
||||
}
|
||||
contentLength = resp.ContentLength
|
||||
}
|
||||
defer source.Close()
|
||||
@@ -640,19 +665,27 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string
|
||||
|
||||
_, err = xio.Copy(ctx, io.MultiWriter(outFile, progress), source)
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if errors.Is(err, context.Canceled) {
|
||||
// Clean up partial file on cancellation
|
||||
removePartialFile(tmpFilePath)
|
||||
return err
|
||||
// Detect cancellation via the context (a cause-cancelled read surfaces
|
||||
// the cause, not context.Canceled). Keep the .partial for resume,
|
||||
// except on a deliberate user abort (ErrUserCancelled), which discards
|
||||
// it. A stall-guard abort leaves ctx uncancelled, so it falls through
|
||||
// to the error path below and likewise preserves the partial.
|
||||
if ctx.Err() != nil {
|
||||
if errors.Is(context.Cause(ctx), ErrUserCancelled) {
|
||||
_ = removePartialFile(tmpFilePath)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("failed to write file %q: %v", filePath, err)
|
||||
}
|
||||
|
||||
// Check for cancellation before finalizing
|
||||
// Check for cancellation before finalizing. Keep the .partial for resume
|
||||
// unless the user deliberately aborted.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
removePartialFile(tmpFilePath)
|
||||
if errors.Is(context.Cause(ctx), ErrUserCancelled) {
|
||||
_ = removePartialFile(tmpFilePath)
|
||||
}
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
@@ -243,6 +243,14 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||
for _, t := range s.Tokens {
|
||||
tks = append(tks, int32(t))
|
||||
}
|
||||
words := make([]*pb.TranscriptWord, 0, len(s.Words))
|
||||
for _, w := range s.Words {
|
||||
words = append(words, &pb.TranscriptWord{
|
||||
Start: int64(w.Start),
|
||||
End: int64(w.End),
|
||||
Text: w.Text,
|
||||
})
|
||||
}
|
||||
tresult.Segments = append(tresult.Segments,
|
||||
&pb.TranscriptSegment{
|
||||
Text: s.Text,
|
||||
@@ -251,6 +259,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||
End: int64(s.End),
|
||||
Tokens: tks,
|
||||
Speaker: s.Speaker,
|
||||
Words: words,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -154,11 +154,20 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
env := os.Environ()
|
||||
// Vulkan backends are self-contained: they bundle their own loader and
|
||||
// Mesa driver .so files in lib/ plus the matching ICD manifests in
|
||||
// vulkan/icd.d/. Point the loader at those manifests so it doesn't rely on
|
||||
// the runtime base image shipping a Vulkan driver (it carries the
|
||||
// SYCL/Level-Zero stack instead, so the default ICD search path is empty
|
||||
// and the GPU would silently fall back to CPU). No-op for other backends.
|
||||
env = append(env, vulkanICDEnv(workDir)...)
|
||||
|
||||
grpcControlProcess := process.New(
|
||||
process.WithTemporaryStateDir(),
|
||||
process.WithName(filepath.Base(grpcProcess)),
|
||||
process.WithArgs(append(args, []string{"--addr", serverAddress}...)...),
|
||||
process.WithEnvironment(os.Environ()...),
|
||||
process.WithEnvironment(env...),
|
||||
process.WithWorkDir(workDir),
|
||||
)
|
||||
|
||||
@@ -249,3 +258,38 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
|
||||
|
||||
return grpcControlProcess, nil
|
||||
}
|
||||
|
||||
// vulkanICDEnv returns environment overrides that point the Vulkan loader at
|
||||
// the ICD manifests a backend bundles in <workDir>/vulkan/icd.d. Vulkan
|
||||
// backends ship a self-contained stack — their own loader and Mesa driver .so
|
||||
// files in lib/ (resolved via the LD_LIBRARY_PATH that run.sh sets) plus the
|
||||
// matching ICD manifests — so the loader must be told where those manifests
|
||||
// live; its default search path (/usr/share/vulkan/icd.d, /etc/vulkan/icd.d)
|
||||
// is empty on the runtime base image. Returns nil when the directory holds no
|
||||
// manifests (CPU/CUDA/SYCL builds), leaving the host's Vulkan setup untouched.
|
||||
func vulkanICDEnv(workDir string) []string {
|
||||
icdDir := filepath.Join(workDir, "vulkan", "icd.d")
|
||||
entries, err := os.ReadDir(icdDir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
manifests = append(manifests, filepath.Join(icdDir, e.Name()))
|
||||
}
|
||||
if len(manifests) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
list := strings.Join(manifests, string(os.PathListSeparator))
|
||||
// VK_DRIVER_FILES is the current loader variable; VK_ICD_FILENAMES is its
|
||||
// deprecated alias, set too so older bundled loaders still pick it up.
|
||||
return []string{
|
||||
"VK_DRIVER_FILES=" + list,
|
||||
"VK_ICD_FILENAMES=" + list,
|
||||
}
|
||||
}
|
||||
|
||||
58
pkg/model/process_vulkan_test.go
Normal file
58
pkg/model/process_vulkan_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("vulkanICDEnv", func() {
|
||||
It("returns nil when the backend ships no vulkan/icd.d (CPU/CUDA/SYCL builds)", func() {
|
||||
Expect(vulkanICDEnv(GinkgoT().TempDir())).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when icd.d exists but holds no .json manifests", func() {
|
||||
work := GinkgoT().TempDir()
|
||||
icdDir := filepath.Join(work, "vulkan", "icd.d")
|
||||
Expect(os.MkdirAll(icdDir, 0o755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(icdDir, "README.txt"), []byte("not a manifest"), 0o644)).To(Succeed())
|
||||
// A directory whose name ends in .json must be ignored.
|
||||
Expect(os.MkdirAll(filepath.Join(icdDir, "nested.json"), 0o755)).To(Succeed())
|
||||
|
||||
Expect(vulkanICDEnv(work)).To(BeNil())
|
||||
})
|
||||
|
||||
It("points VK_DRIVER_FILES/VK_ICD_FILENAMES at the bundled manifests", func() {
|
||||
work := GinkgoT().TempDir()
|
||||
icdDir := filepath.Join(work, "vulkan", "icd.d")
|
||||
Expect(os.MkdirAll(icdDir, 0o755)).To(Succeed())
|
||||
for _, name := range []string{"intel_icd.json", "lvp_icd.json"} {
|
||||
Expect(os.WriteFile(filepath.Join(icdDir, name), []byte("{}"), 0o644)).To(Succeed())
|
||||
}
|
||||
|
||||
env := vulkanICDEnv(work)
|
||||
Expect(env).To(HaveLen(2))
|
||||
|
||||
got := map[string]string{}
|
||||
for _, kv := range env {
|
||||
k, v, ok := strings.Cut(kv, "=")
|
||||
Expect(ok).To(BeTrue(), "malformed env entry %q", kv)
|
||||
got[k] = v
|
||||
}
|
||||
|
||||
for _, key := range []string{"VK_DRIVER_FILES", "VK_ICD_FILENAMES"} {
|
||||
Expect(got).To(HaveKey(key))
|
||||
// Both manifests must be listed as absolute paths, joined by the
|
||||
// OS path-list separator the Vulkan loader expects.
|
||||
parts := strings.Split(got[key], string(os.PathListSeparator))
|
||||
Expect(parts).To(HaveLen(2))
|
||||
for _, p := range parts {
|
||||
Expect(filepath.IsAbs(p)).To(BeTrue(), "%s entry %q must be absolute", key, p)
|
||||
Expect(p).To(HaveSuffix(".json"))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
23
pkg/xsysinfo/computecap_internal_test.go
Normal file
23
pkg/xsysinfo/computecap_internal_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package xsysinfo
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("parseComputeCap", func() {
|
||||
DescribeTable("splits major.minor",
|
||||
func(in string, maj, min int) {
|
||||
m, n := parseComputeCap(in)
|
||||
Expect(m).To(Equal(maj))
|
||||
Expect(n).To(Equal(min))
|
||||
},
|
||||
Entry("GB10 / DGX Spark", "12.1", 12, 1),
|
||||
Entry("RTX 50-series", "12.0", 12, 0),
|
||||
Entry("Hopper", "9.0", 9, 0),
|
||||
Entry("major only", "12", 12, 0),
|
||||
Entry("whitespace", " 12.1 ", 12, 1),
|
||||
Entry("empty", "", -1, -1),
|
||||
Entry("garbage", "abc", -1, -1),
|
||||
)
|
||||
})
|
||||
@@ -38,9 +38,9 @@ var UnifiedMemoryDevices = []string{
|
||||
|
||||
// GPUMemoryInfo contains real-time GPU memory usage information
|
||||
type GPUMemoryInfo struct {
|
||||
Index int `json:"index"`
|
||||
Name string `json:"name"`
|
||||
Vendor string `json:"vendor"`
|
||||
Index int `json:"index"`
|
||||
Name string `json:"name"`
|
||||
Vendor string `json:"vendor"`
|
||||
// BDF is the canonical PCI bus address (dddd:bb:dd.f) when known.
|
||||
// Populated by detection paths that can attribute the device to a
|
||||
// PCI location (clinfo, future amdgpu/nvidia paths); empty for
|
||||
@@ -307,6 +307,84 @@ func GetGPUAggregateInfo() GPUAggregateInfo {
|
||||
return aggregate
|
||||
}
|
||||
|
||||
var (
|
||||
computeCapOnce sync.Once
|
||||
computeCapResult string
|
||||
)
|
||||
|
||||
// NVIDIAComputeCapability returns the highest NVIDIA GPU compute capability on
|
||||
// this host as a "major.minor" string (e.g. "12.1" for GB10 / DGX Spark), or ""
|
||||
// when nvidia-smi is unavailable or reports none. Detected once and cached.
|
||||
//
|
||||
// This runs where the GPU actually is. In distributed mode it is reported by
|
||||
// each worker on registration so the router can make per-node decisions rather
|
||||
// than guessing from the (possibly GPU-less) frontend host.
|
||||
func NVIDIAComputeCapability() string {
|
||||
computeCapOnce.Do(func() {
|
||||
computeCapResult = detectNVIDIAComputeCapability()
|
||||
})
|
||||
return computeCapResult
|
||||
}
|
||||
|
||||
func detectNVIDIAComputeCapability() string {
|
||||
if _, err := exec.LookPath("nvidia-smi"); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
cmd := exec.Command("nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader")
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
xlog.Debug("nvidia-smi compute_cap query failed", "error", err, "stderr", stderr.String())
|
||||
return ""
|
||||
}
|
||||
|
||||
best := ""
|
||||
bestMajor, bestMinor := -1, -1
|
||||
for _, line := range strings.Split(strings.TrimSpace(stdout.String()), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
maj, min := parseComputeCap(line)
|
||||
if maj < 0 {
|
||||
continue
|
||||
}
|
||||
if maj > bestMajor || (maj == bestMajor && min > bestMinor) {
|
||||
bestMajor, bestMinor, best = maj, min, line
|
||||
}
|
||||
}
|
||||
if best != "" {
|
||||
xlog.Debug("NVIDIA compute capability detected", "compute_cap", best)
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
// parseComputeCap splits a "major.minor" compute-capability string into its
|
||||
// integer parts. Returns (-1, -1) if it can't be parsed.
|
||||
func parseComputeCap(cc string) (int, int) {
|
||||
cc = strings.TrimSpace(cc)
|
||||
if cc == "" {
|
||||
return -1, -1
|
||||
}
|
||||
majStr, minStr := cc, "0"
|
||||
if dot := strings.IndexByte(cc, '.'); dot >= 0 {
|
||||
majStr, minStr = cc[:dot], cc[dot+1:]
|
||||
}
|
||||
maj, err := strconv.Atoi(strings.TrimSpace(majStr))
|
||||
if err != nil {
|
||||
return -1, -1
|
||||
}
|
||||
min, err := strconv.Atoi(strings.TrimSpace(minStr))
|
||||
if err != nil {
|
||||
min = 0
|
||||
}
|
||||
return maj, min
|
||||
}
|
||||
|
||||
// getNVIDIAGPUMemory queries NVIDIA GPUs using nvidia-smi
|
||||
func getNVIDIAGPUMemory() []GPUMemoryInfo {
|
||||
// Check if nvidia-smi is available
|
||||
@@ -866,12 +944,12 @@ func getVulkanGPUMemory() []GPUMemoryInfo {
|
||||
}
|
||||
|
||||
type vulkanGPUTextInfo struct {
|
||||
index int
|
||||
name string
|
||||
deviceType string
|
||||
totalVRAM uint64
|
||||
budgetVRAM uint64
|
||||
usageVRAM uint64
|
||||
index int
|
||||
name string
|
||||
deviceType string
|
||||
totalVRAM uint64
|
||||
budgetVRAM uint64
|
||||
usageVRAM uint64
|
||||
}
|
||||
|
||||
func parseVulkanGPUMemoryText(r io.Reader) []GPUMemoryInfo {
|
||||
@@ -909,7 +987,7 @@ func parseVulkanGPUMemoryText(r io.Reader) []GPUMemoryInfo {
|
||||
} else if current.usageVRAM != 0 && current.budgetVRAM == 0 {
|
||||
current.budgetVRAM = current.totalVRAM - current.usageVRAM
|
||||
} else if current.usageVRAM == 0 && current.budgetVRAM == 0 {
|
||||
current.usageVRAM = 0
|
||||
current.usageVRAM = 0
|
||||
current.budgetVRAM = current.totalVRAM
|
||||
}
|
||||
|
||||
|
||||
@@ -109,6 +109,38 @@ copy_libs_glob() {
|
||||
done
|
||||
}
|
||||
|
||||
# is_core_lib NAME
# Succeeds (exit 0) when NAME is the file name of a core runtime library: the
# dynamic loader, the libc family, libgcc_s, libstdc++ or the vdso. Fails
# (exit 1) for anything else.
#
# These are provided by the base image and package.sh, and we must NOT bundle
# our own copies — a second libc or libstdc++ on LD_LIBRARY_PATH clashes with
# the loader and the rest of the process — so callers skip them when pulling
# in a driver's transitive deps.
is_core_lib() {
    local name="$1"
    case "$name" in
        # Dynamic loader and the kernel-provided vdso.
        ld-linux*|ld.so|linux-vdso.so.*)
            return 0
            ;;
        # libc and the libraries historically split out of it.
        libc.so.*|libm.so.*|libdl.so.*|libpthread.so.*|librt.so.*|libresolv.so.*|libutil.so.*)
            return 0
            ;;
        # Compiler runtime libraries.
        libgcc_s.so.*|libstdc++.so.*)
            return 0
            ;;
        *)
            return 1
            ;;
    esac
}
|
||||
|
||||
# copy_elf_deps ELF
# Copies the shared libraries that ldd resolves for ELF via copy_lib, skipping
# the core runtime libraries recognised by is_core_lib. Does nothing (and
# succeeds) when ELF does not exist or ldd is not available.
#
# Used to make a bundled GPU driver self-contained: e.g. the Mesa Vulkan ICDs
# pull in libdrm, libexpat and (for RADV/lavapipe) libLLVM, none of which the
# runtime base image is guaranteed to have.
copy_elf_deps() {
    local elf="$1"

    if [ ! -e "$elf" ]; then
        return 0
    fi
    if ! command -v ldd >/dev/null 2>&1; then
        return 0
    fi

    # ldd lines look like: "<TAB>libfoo.so.1 => /path/to/libfoo.so.1 (0x..)".
    # Only lines with "=>" whose third field is an absolute path are kept, so
    # vdso/static entries never reach the loop.
    local dep_path
    while read -r dep_path; do
        if ! is_core_lib "$(basename "$dep_path")"; then
            copy_lib "$dep_path"
        fi
    done < <(ldd "$elf" 2>/dev/null | awk '/=>/ && $3 ~ /^\// {print $3}')
}
|
||||
|
||||
# Package NVIDIA CUDA libraries
|
||||
package_cuda_libs() {
|
||||
echo "Packaging CUDA libraries for BUILD_TYPE=${BUILD_TYPE}..."
|
||||
@@ -284,7 +316,7 @@ package_vulkan_libs() {
|
||||
"/usr/local/lib"
|
||||
)
|
||||
|
||||
# Core Vulkan runtime libraries
|
||||
# Core Vulkan runtime: the loader plus the shader tooling shipped by the SDK.
|
||||
local vulkan_libs=(
|
||||
"libvulkan.so*"
|
||||
"libshaderc_shared.so*"
|
||||
@@ -301,10 +333,63 @@ package_vulkan_libs() {
|
||||
fi
|
||||
done
|
||||
|
||||
# Copy Vulkan ICD files
|
||||
# Bundle the ICD drivers. Rather than hard-code Mesa's (platform- and
|
||||
# version-dependent) driver sonames, treat each installed ICD manifest as
|
||||
# the source of truth: every /usr/share/vulkan/icd.d/*.json names the exact
|
||||
# driver .so it needs in its "library_path". So we copy whatever drivers
|
||||
# the manifests reference (libvulkan_intel/radeon/lvp/... on amd64, the SoC
|
||||
# drivers on arm64, ...) plus each driver's transitive deps, and skip any
|
||||
# manifest whose driver isn't actually installed. The loader picks the
|
||||
# right driver for the GPU at runtime.
|
||||
if [ -d "/usr/share/vulkan/icd.d" ]; then
|
||||
mkdir -p "$TARGET_LIB_DIR/../vulkan/icd.d"
|
||||
cp -arfL /usr/share/vulkan/icd.d/* "$TARGET_LIB_DIR/../vulkan/icd.d/" 2>/dev/null || true
|
||||
local icd_dest="$TARGET_LIB_DIR/../vulkan/icd.d"
|
||||
mkdir -p "$icd_dest"
|
||||
|
||||
local manifest driver driver_base resolved lib_path
|
||||
for manifest in /usr/share/vulkan/icd.d/*.json; do
|
||||
[ -e "$manifest" ] || continue
|
||||
|
||||
# Pull the driver path out of "library_path": "<path-or-soname>".
|
||||
driver=$(sed -nE 's/.*"library_path"[[:space:]]*:[[:space:]]*"([^"]+)".*/\1/p' "$manifest" | head -n1)
|
||||
[ -n "$driver" ] || continue
|
||||
driver_base=$(basename "$driver")
|
||||
|
||||
# Resolve to an absolute path: honour an absolute library_path,
|
||||
# else look in the standard lib dirs, else fall back to ldconfig.
|
||||
resolved=""
|
||||
case "$driver" in
|
||||
/*) [ -e "$driver" ] && resolved="$driver" ;;
|
||||
esac
|
||||
if [ -z "$resolved" ]; then
|
||||
for lib_path in "${vulkan_lib_paths[@]}"; do
|
||||
if [ -e "${lib_path}/${driver_base}" ]; then
|
||||
resolved="${lib_path}/${driver_base}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
fi
|
||||
if [ -z "$resolved" ] && command -v ldconfig >/dev/null 2>&1; then
|
||||
resolved=$(ldconfig -p | awk -v n="$driver_base" '$1 == n { print $NF; exit }')
|
||||
fi
|
||||
|
||||
if [ -z "$resolved" ] || [ ! -e "$resolved" ]; then
|
||||
echo "Vulkan ICD: driver '$driver_base' for $(basename "$manifest") not installed; skipping its manifest" >&2
|
||||
continue
|
||||
fi
|
||||
|
||||
# Bundle the driver + its transitive deps (libdrm, libexpat, and
|
||||
# libLLVM for RADV/lavapipe, ...) so the backend is self-contained
|
||||
# on a runtime base image without Mesa.
|
||||
copy_lib "$resolved"
|
||||
copy_elf_deps "$resolved"
|
||||
|
||||
# Copy the manifest and rewrite its library_path to a bare soname
|
||||
# so the loader resolves our bundled driver via LD_LIBRARY_PATH
|
||||
# (run.sh adds lib/ to it) instead of a host path that won't exist
|
||||
# on the runtime image.
|
||||
cp -arfL "$manifest" "$icd_dest/" 2>/dev/null || true
|
||||
sed -i -E 's#("library_path"[[:space:]]*:[[:space:]]*")[^"]*/#\1#' "$icd_dest/$(basename "$manifest")"
|
||||
done
|
||||
fi
|
||||
|
||||
echo "Vulkan libraries packaged successfully"
|
||||
@@ -345,6 +430,8 @@ package_gpu_libs() {
|
||||
export -f package_gpu_libs
|
||||
export -f copy_lib
|
||||
export -f copy_libs_glob
|
||||
export -f is_core_lib
|
||||
export -f copy_elf_deps
|
||||
export -f package_cuda_libs
|
||||
export -f package_rocm_libs
|
||||
export -f package_intel_libs
|
||||
|
||||
Reference in New Issue
Block a user