mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-22 23:58:25 -04:00
Compare commits
39 Commits
fix/turboq
...
fix/distri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
969cb850b5 | ||
|
|
6240e16ad8 | ||
|
|
780e720593 | ||
|
|
a72649b486 | ||
|
|
f96df5eb85 | ||
|
|
e14d9ae8e3 | ||
|
|
a97dc3bf57 | ||
|
|
a560329430 | ||
|
|
07b2e4e703 | ||
|
|
f03aacf7e7 | ||
|
|
e8e75aadb6 | ||
|
|
352ebe241d | ||
|
|
c851768163 | ||
|
|
58f29496e5 | ||
|
|
871f5e0b7a | ||
|
|
148fb1c0d3 | ||
|
|
fbb20ed2b3 | ||
|
|
17ea9f93f9 | ||
|
|
4b66c3ad45 | ||
|
|
169ff75633 | ||
|
|
4f89882057 | ||
|
|
f9b47c6eab | ||
|
|
4306b730ed | ||
|
|
71d940f1e0 | ||
|
|
0e2b84d8e3 | ||
|
|
61bf34ea2f | ||
|
|
0b2ae3c6ca | ||
|
|
4735345105 | ||
|
|
7384fd800b | ||
|
|
6942713d85 | ||
|
|
0cf52c44d4 | ||
|
|
0d34cf7cbd | ||
|
|
f0cb02afb8 | ||
|
|
a39e025d64 | ||
|
|
05e8e1e9f4 | ||
|
|
a7f6cc8956 | ||
|
|
f15b9178ec | ||
|
|
959de86761 | ||
|
|
4c234abc2c |
1
.github/workflows/image_build.yml
vendored
1
.github/workflows/image_build.yml
vendored
@@ -106,6 +106,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
1
.github/workflows/image_merge.yml
vendored
1
.github/workflows/image_merge.yml
vendored
@@ -80,6 +80,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -77,3 +77,6 @@ local-backends/
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
# Local worktrees
|
||||
.worktrees/
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
|
||||
# Upstream pin lives below as DS4_VERSION?=8d576642c39b9a2d782a80159ba84ef5a81c0b81
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
|
||||
DS4_VERSION?=8d576642c39b9a2d782a80159ba84ef5a81c0b81
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=11a1fea9e291f12ce2c803a9d7812c30ca806bcf
|
||||
IK_LLAMA_VERSION?=48a55f74e4c6e2aeda363dd386c1ac9170a0af71
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=ad277572619fcfb6ddd38f4c6437283a4b2b8636
|
||||
LLAMA_VERSION?=bb28c1fe246b72276ee1d00ce89306be7b865766
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -517,10 +517,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.warmup = true;
|
||||
// no_op_offload: disable host tensor op offload (default: false)
|
||||
params.no_op_offload = false;
|
||||
// kv_unified: enable unified KV cache (default: false)
|
||||
params.kv_unified = false;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
||||
params.n_ctx_checkpoints = 8;
|
||||
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
|
||||
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
|
||||
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
|
||||
// and leave kv_unified=false. We flip the default to true here so the
|
||||
// server-side prompt cache (cache_idle_slots) is actually usable on the
|
||||
// single-slot path that LocalAI ships with: without it, idle slots are
|
||||
// never persisted across requests and the prompt cache is dead weight.
|
||||
// Users can opt out with `options: [ "kv_unified:false" ]`.
|
||||
params.kv_unified = true;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
|
||||
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
|
||||
// and limits partial-prefix recovery without a clear memory rationale.
|
||||
params.n_ctx_checkpoints = 32;
|
||||
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
|
||||
// task switch. Upstream default is true; the server auto-disables it if
|
||||
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
|
||||
// what actually unlocks it.
|
||||
params.cache_idle_slots = true;
|
||||
// checkpoint_every_nt: create a context checkpoint every N tokens during
|
||||
// prefill (-1 disables). Match upstream's default (8192).
|
||||
params.checkpoint_every_nt = 8192;
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -679,7 +696,29 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
try {
|
||||
params.n_ctx_checkpoints = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8)
|
||||
// If conversion fails, keep default value (32)
|
||||
}
|
||||
}
|
||||
|
||||
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
|
||||
// Saves the slot's KV state into the host-side prompt cache on task
|
||||
// switch so a later request with the same prefix can warm-load it.
|
||||
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
|
||||
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.cache_idle_slots = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
// --- prefill checkpoint cadence (upstream -cpent / --checkpoint-every-n-tokens) ---
|
||||
// -1 disables checkpointing during prefill.
|
||||
} else if (!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
params.checkpoint_every_nt = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8192)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=4c1c3ac09d2dba0aa9a55b94f6c50c41a92f9c8c
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
#!/bin/bash
|
||||
# Patch the shared backend/cpp/llama-cpp/grpc-server.cpp *copy* used by the
|
||||
# turboquant build:
|
||||
# turboquant build to account for the gaps between upstream and the fork:
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
#
|
||||
# Historical context: this script used to also paper over API gaps between the
|
||||
# fork and upstream (flat vs nested `common_params_speculative`, missing
|
||||
# `get_media_marker()`, `ctx_server.impl->model` vs `model_tgt`, and a
|
||||
# LOCALAI_LEGACY_LLAMA_CPP_SPEC compile gate). As of TURBOQUANT_VERSION
|
||||
# 4c1c3ac0 the fork has rebased past ggml-org/llama.cpp#21962, #22397 and
|
||||
# #22838, so the shared grpc-server.cpp compiles unmodified against the fork.
|
||||
# Only the fork-specific KV-cache enum entries remain.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build stays compiling
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
# against vanilla upstream.
|
||||
#
|
||||
# Idempotent: skips the insertion if its marker is already present (so re-runs
|
||||
# Idempotent: skips each insertion if its marker is already present (so re-runs
|
||||
# of the same build dir don't double-insert).
|
||||
|
||||
set -euo pipefail
|
||||
@@ -45,7 +52,7 @@ else
|
||||
awk '
|
||||
/^ GGML_TYPE_Q5_1,$/ && !done {
|
||||
print
|
||||
print " // turboquant fork extras - added by patch-grpc-server.sh"
|
||||
print " // turboquant fork extras — added by patch-grpc-server.sh"
|
||||
print " GGML_TYPE_TURBO2_0,"
|
||||
print " GGML_TYPE_TURBO3_0,"
|
||||
print " GGML_TYPE_TURBO4_0,"
|
||||
@@ -65,4 +72,83 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout. Those
|
||||
# blocks reference struct fields that simply do not exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
}
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=5b0267e941cade15bd80089d89838795d9f4baa6
|
||||
STABLEDIFFUSION_GGML_VERSION?=3a8788cb7d74f185d6b18688e9563015524ecaf5
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=afa2ea544fb4b0448916b4a31ecd33c8685bd482
|
||||
WHISPER_CPP_VERSION?=8443cf05e3fa8ce1b32348e1bcbcf8fc31f7f3ae
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -233,7 +233,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
|
||||
registry,
|
||||
natsClient,
|
||||
cfg.Distributed.BackendInstallTimeoutOrDefault(),
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
@@ -200,7 +200,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -552,6 +552,13 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
// Allow the on-disk setting to override the CLI/env default. The
|
||||
// startup default is non-zero (see NewApplicationConfig), so a plain
|
||||
// `== 0` guard like the others would never trigger; we instead respect
|
||||
// any value the file specifies. 0 in the file means "uncapped".
|
||||
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
|
||||
// Branding / whitelabeling. There are no env vars for these — the file is
|
||||
// the only source — so apply unconditionally. Without this block a server
|
||||
|
||||
@@ -39,19 +39,19 @@ type RunCMD struct {
|
||||
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
||||
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
||||
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||
PreferDevelopmentBackends bool `env:"LOCALAI_PREFER_DEV_BACKENDS,PREFER_DEV_BACKENDS" help:"Prefer development backend versions (shows development backends by default in UI)" group:"backends" default:"false"`
|
||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
|
||||
F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
|
||||
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
||||
@@ -100,6 +100,7 @@ type RunCMD struct {
|
||||
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||
TracingMaxBodyBytes int `env:"LOCALAI_TRACING_MAX_BODY_BYTES" default:"65536" help:"Maximum bytes captured per request/response body in the trace buffer (0 = uncapped). Caps memory growth from chatty endpoints like /embeddings." group:"api"`
|
||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
@@ -144,16 +145,18 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
@@ -254,6 +257,20 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.StorageSecretKey != "" {
|
||||
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||
}
|
||||
if r.BackendInstallTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendInstallTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT %q: %w", r.BackendInstallTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendInstallTimeout(d))
|
||||
}
|
||||
if r.BackendUpgradeTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendUpgradeTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT %q: %w", r.BackendUpgradeTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendUpgradeTimeout(d))
|
||||
}
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
@@ -273,6 +290,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.EnableTracing)
|
||||
}
|
||||
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
|
||||
opts = append(opts, config.WithTracingMaxBodyBytes(r.TracingMaxBodyBytes))
|
||||
|
||||
token := ""
|
||||
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
||||
|
||||
@@ -21,6 +21,7 @@ type ApplicationConfig struct {
|
||||
Debug bool
|
||||
EnableTracing bool
|
||||
TracingMaxItems int
|
||||
TracingMaxBodyBytes int // Per-body cap for captured request/response bodies; 0 disables the cap
|
||||
EnableBackendLogging bool
|
||||
GeneratedContentDir string
|
||||
|
||||
@@ -187,6 +188,7 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
@@ -578,6 +580,12 @@ func WithTracingMaxItems(items int) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithTracingMaxBodyBytes(bytes int) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.TracingMaxBodyBytes = bytes
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeneratedContentDir(generatedContentDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.GeneratedContentDir = generatedContentDir
|
||||
@@ -920,6 +928,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
f16 := o.F16
|
||||
debug := o.Debug
|
||||
tracingMaxItems := o.TracingMaxItems
|
||||
tracingMaxBodyBytes := o.TracingMaxBodyBytes
|
||||
enableTracing := o.EnableTracing
|
||||
enableBackendLogging := o.EnableBackendLogging
|
||||
cors := o.CORS
|
||||
@@ -1008,6 +1017,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
F16: &f16,
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
TracingMaxBodyBytes: &tracingMaxBodyBytes,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
@@ -1146,6 +1156,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.TracingMaxItems != nil {
|
||||
o.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
o.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
if settings.EnableBackendLogging != nil {
|
||||
o.EnableBackendLogging = *settings.EnableBackendLogging
|
||||
}
|
||||
|
||||
@@ -40,7 +40,10 @@ type DistributedConfig struct {
|
||||
// model-row cleanup on MarkUnhealthy / MarkDraining).
|
||||
DisablePerModelHealthCheck bool
|
||||
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
|
||||
BackendInstallTimeout time.Duration // NATS round-trip timeout for backend.install (default 15m)
|
||||
BackendUpgradeTimeout time.Duration // NATS round-trip timeout for backend.upgrade (default 15m)
|
||||
|
||||
MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
|
||||
|
||||
@@ -68,13 +71,15 @@ func (c DistributedConfig) Validate() error {
|
||||
}
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||
"drain-timeout": c.DrainTimeout,
|
||||
"health-check-interval": c.HealthCheckInterval,
|
||||
"stale-node-threshold": c.StaleNodeThreshold,
|
||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||
"drain-timeout": c.DrainTimeout,
|
||||
"health-check-interval": c.HealthCheckInterval,
|
||||
"stale-node-threshold": c.StaleNodeThreshold,
|
||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||
"backend-install-timeout": c.BackendInstallTimeout,
|
||||
"backend-upgrade-timeout": c.BackendUpgradeTimeout,
|
||||
} {
|
||||
if d < 0 {
|
||||
return fmt.Errorf("%s must not be negative", name)
|
||||
@@ -137,24 +142,48 @@ func WithStorageSecretKey(key string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendInstallTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendInstallTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendUpgradeTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendUpgradeTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// Defaults for distributed timeouts. Each value is what the matching
// ...OrDefault accessor on DistributedConfig falls back to when the
// configured duration is zero.
const (
	DefaultMCPToolTimeout        = 360 * time.Second
	DefaultMCPDiscoveryTimeout   = 60 * time.Second
	DefaultWorkerWaitTimeout     = 5 * time.Minute
	DefaultDrainTimeout          = 30 * time.Second
	DefaultHealthCheckInterval   = 15 * time.Second
	DefaultStaleNodeThreshold    = 60 * time.Second
	DefaultMCPCIJobTimeout       = 10 * time.Minute
	DefaultBackendInstallTimeout = 15 * time.Minute
	DefaultBackendUpgradeTimeout = 15 * time.Minute
)
|
||||
|
||||
// DefaultMaxUploadSize is the default cap on an upload body, in bytes:
// 50 GiB-sized units of 1<<30 bytes each (50 GB).
const DefaultMaxUploadSize int64 = 50 * (1 << 30)
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
}
|
||||
|
||||
// BackendUpgradeTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendUpgradeTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendUpgradeTimeout, DefaultBackendUpgradeTimeout)
|
||||
}
|
||||
|
||||
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
||||
|
||||
36
core/config/distributed_config_test.go
Normal file
36
core/config/distributed_config_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
var _ = Describe("DistributedConfig backend NATS timeouts", func() {
|
||||
Context("BackendInstallTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendInstallTimeout: 42 * time.Minute}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(42 * time.Minute))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BackendUpgradeTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendUpgradeTimeout: 30 * time.Minute}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(30 * time.Minute))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -38,6 +38,7 @@ type RuntimeSettings struct {
|
||||
Debug *bool `json:"debug,omitempty"`
|
||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||
TracingMaxBodyBytes *int `json:"tracing_max_body_bytes,omitempty"` // Per-body cap in bytes; 0 disables the cap
|
||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||
|
||||
// Security/CORS settings
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/quantization"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -267,9 +268,12 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Static("/generated-videos", videoPath)
|
||||
}
|
||||
|
||||
// Initialize usage recording when auth DB is available
|
||||
// Initialize usage recording when auth DB is available, and ensure the
|
||||
// batcher drains its in-memory queue on graceful shutdown so the last
|
||||
// few seconds of usage don't disappear when the process exits.
|
||||
if application.AuthDB() != nil {
|
||||
httpMiddleware.InitUsageRecorder(application.AuthDB())
|
||||
signals.RegisterGracefulTerminationHandler(httpMiddleware.ShutdownUsageRecorder)
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
|
||||
@@ -403,7 +407,7 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -38,9 +38,15 @@ func InitDB(databaseURL string) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
// Backfill: users created before the provider column existed have an empty
|
||||
// provider — treat them as local accounts so the UI can identify them.
|
||||
// provider - treat them as local accounts so the UI can identify them.
|
||||
db.Exec("UPDATE users SET provider = ? WHERE provider = '' OR provider IS NULL", ProviderLocal)
|
||||
|
||||
// Backfill: pre-feature usage_records have no source column. Classify them so the
|
||||
// new per-source aggregators include them.
|
||||
if err := BackfillUsageSource(db); err != nil {
|
||||
return nil, fmt.Errorf("failed to backfill usage source: %w", err)
|
||||
}
|
||||
|
||||
// Create composite index on users(provider, subject) for fast OAuth lookups
|
||||
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
|
||||
// Ignore error on postgres if index already exists
|
||||
|
||||
@@ -16,8 +16,10 @@ import (
|
||||
)
|
||||
|
||||
// Echo context keys under which the auth middleware stores per-request
// authentication state. Each key is declared exactly once.
const (
	contextKeyUser   = "auth_user"   // the authenticated user
	contextKeyRole   = "auth_role"   // the role granted to the request
	contextKeyAPIKey = "auth_apikey" // resolved *UserAPIKey, read back by GetAPIKey
	contextKeySource = "auth_source" // authentication source string, read back by GetSource
)
|
||||
|
||||
// Middleware returns an Echo middleware that handles authentication.
|
||||
@@ -75,6 +77,7 @@ func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.Middlewar
|
||||
}
|
||||
c.Set(contextKeyUser, syntheticUser)
|
||||
c.Set(contextKeyRole, RoleAdmin)
|
||||
c.Set(contextKeySource, UsageSourceLegacy)
|
||||
authenticated = true
|
||||
}
|
||||
}
|
||||
@@ -213,6 +216,20 @@ func GetUserRole(c echo.Context) string {
|
||||
return role
|
||||
}
|
||||
|
||||
// GetAPIKey returns the resolved API key from the echo context, or nil.
|
||||
// Nil for session-cookie and legacy-env-key authentication.
|
||||
func GetAPIKey(c echo.Context) *UserAPIKey {
|
||||
k, _ := c.Get(contextKeyAPIKey).(*UserAPIKey)
|
||||
return k
|
||||
}
|
||||
|
||||
// GetSource returns the request's authentication source: UsageSourceAPIKey,
|
||||
// UsageSourceWeb, UsageSourceLegacy, or empty if no authentication was performed.
|
||||
func GetSource(c echo.Context) string {
|
||||
s, _ := c.Get(contextKeySource).(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// RequireRouteFeature returns a global middleware that checks the user has access
|
||||
// to the feature required by the matched route. It uses the RouteFeatureRegistry
|
||||
// to look up the required feature for each route pattern + HTTP method.
|
||||
@@ -421,47 +438,67 @@ func RequireQuota(db *gorm.DB) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
// tryAuthenticate attempts to authenticate the request using the database.
|
||||
//
|
||||
// On success it returns the user and, as a side effect, sets the following
|
||||
// values on the Echo context:
|
||||
// - contextKeySource ("auth_source"): always set, one of UsageSourceWeb /
|
||||
// UsageSourceAPIKey. UsageSourceLegacy is set elsewhere by the parent
|
||||
// Middleware when a legacy env key matches.
|
||||
// - contextKeyAPIKey ("auth_apikey"): set to the resolved *UserAPIKey for
|
||||
// named-key branches (Bearer, x-api-key, xi-api-key, token cookie).
|
||||
// - "_auth_session": session record, used by Middleware to drive cookie
|
||||
// rotation. Only set on the session-cookie branch.
|
||||
//
|
||||
// contextKeyUser and contextKeyRole are populated by the parent Middleware
|
||||
// after this function returns.
|
||||
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
|
||||
hmacSecret := appConfig.Auth.APIKeyHMACSecret
|
||||
|
||||
// a. Session cookie
|
||||
// a. Session cookie -> web UI
|
||||
if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" {
|
||||
if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil {
|
||||
// Store session for rotation check in middleware
|
||||
c.Set("_auth_session", session)
|
||||
c.Set(contextKeySource, UsageSourceWeb)
|
||||
return user
|
||||
}
|
||||
}
|
||||
|
||||
// b. Authorization: Bearer token
|
||||
// b. Authorization: Bearer
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
// Try as session ID first
|
||||
// b1. Session token via Bearer -> still web UI
|
||||
if user, _ := ValidateSession(db, token, hmacSecret); user != nil {
|
||||
c.Set(contextKeySource, UsageSourceWeb)
|
||||
return user
|
||||
}
|
||||
|
||||
// Try as user API key
|
||||
// b2. Named API key
|
||||
if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil {
|
||||
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||
c.Set(contextKeyAPIKey, key)
|
||||
return &key.User
|
||||
}
|
||||
}
|
||||
|
||||
// c. x-api-key / xi-api-key headers
|
||||
// c. x-api-key / xi-api-key -> named API key
|
||||
for _, header := range []string{"x-api-key", "xi-api-key"} {
|
||||
if key := c.Request().Header.Get(header); key != "" {
|
||||
if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil {
|
||||
if k := c.Request().Header.Get(header); k != "" {
|
||||
if apiKey, err := ValidateAPIKey(db, k, hmacSecret); err == nil {
|
||||
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||
c.Set(contextKeyAPIKey, apiKey)
|
||||
return &apiKey.User
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// d. token cookie (legacy)
|
||||
// d. token cookie -> named API key
|
||||
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
|
||||
// Try as user API key
|
||||
if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil {
|
||||
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||
c.Set(contextKeyAPIKey, key)
|
||||
return &key.User
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,4 +303,122 @@ var _ = Describe("Auth Middleware", func() {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("auth context plumbing for usage source", func() {
	// observed records what the "/probe" handler read from the request
	// context: the user, the authentication source and the resolved API key.
	type observed struct {
		user   *auth.User
		source string
		key    *auth.UserAPIKey
	}

	// newProbeApp builds a minimal echo app with the auth middleware in front
	// of a single GET "/probe" route whose handler copies the context values
	// into out.
	newProbeApp := func(db *gorm.DB, appConfig *config.ApplicationConfig, out *observed) *echo.Echo {
		app := echo.New()
		app.Use(auth.Middleware(db, appConfig))
		app.GET("/probe", func(c echo.Context) error {
			out.user = auth.GetUser(c)
			out.source = auth.GetSource(c)
			out.key = auth.GetAPIKey(c)
			return c.NoContent(http.StatusOK)
		})
		return app
	}

	It("session cookie sets source=web, apikey=nil", func() {
		db := testDB()
		appConfig := config.NewApplicationConfig()
		alice := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
		sessionToken := createTestSession(db, alice.ID)

		got := observed{}
		app := newProbeApp(db, appConfig, &got)
		rec := doRequest(app, http.MethodGet, "/probe", withSessionCookie(sessionToken))

		Expect(rec.Code).To(Equal(http.StatusOK))
		Expect(got.user).ToNot(BeNil())
		Expect(got.user.ID).To(Equal(alice.ID))
		Expect(got.source).To(Equal(auth.UsageSourceWeb))
		Expect(got.key).To(BeNil())
	})

	It("Bearer session token sets source=web, apikey=nil", func() {
		db := testDB()
		appConfig := config.NewApplicationConfig()
		alice := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
		sessionToken := createTestSession(db, alice.ID)

		got := observed{}
		app := newProbeApp(db, appConfig, &got)
		rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(sessionToken))

		Expect(rec.Code).To(Equal(http.StatusOK))
		Expect(got.user).ToNot(BeNil())
		Expect(got.user.ID).To(Equal(alice.ID))
		Expect(got.source).To(Equal(auth.UsageSourceWeb))
		Expect(got.key).To(BeNil())
	})

	It("Bearer API key sets source=apikey and exposes the resolved *UserAPIKey", func() {
		db := testDB()
		appConfig := config.NewApplicationConfig()
		alice := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
		secret, created, err := auth.CreateAPIKey(db, alice.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
		Expect(err).ToNot(HaveOccurred())

		got := observed{}
		app := newProbeApp(db, appConfig, &got)
		rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(secret))

		Expect(rec.Code).To(Equal(http.StatusOK))
		Expect(got.source).To(Equal(auth.UsageSourceAPIKey))
		Expect(got.key).ToNot(BeNil())
		Expect(got.key.ID).To(Equal(created.ID))
	})

	It("x-api-key header sets source=apikey", func() {
		db := testDB()
		appConfig := config.NewApplicationConfig()
		alice := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
		secret, _, err := auth.CreateAPIKey(db, alice.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
		Expect(err).ToNot(HaveOccurred())

		got := observed{}
		app := newProbeApp(db, appConfig, &got)
		rec := doRequest(app, http.MethodGet, "/probe", withXApiKey(secret))

		Expect(rec.Code).To(Equal(http.StatusOK))
		Expect(got.source).To(Equal(auth.UsageSourceAPIKey))
		Expect(got.key).ToNot(BeNil())
	})

	It("token cookie sets source=apikey", func() {
		db := testDB()
		appConfig := config.NewApplicationConfig()
		alice := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
		secret, _, err := auth.CreateAPIKey(db, alice.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
		Expect(err).ToNot(HaveOccurred())

		got := observed{}
		app := newProbeApp(db, appConfig, &got)
		rec := doRequest(app, http.MethodGet, "/probe", withTokenCookie(secret))

		Expect(rec.Code).To(Equal(http.StatusOK))
		Expect(got.source).To(Equal(auth.UsageSourceAPIKey))
		Expect(got.key).ToNot(BeNil())
	})

	It("legacy env key sets source=legacy, apikey=nil", func() {
		db := testDB()
		appConfig := config.NewApplicationConfig()
		appConfig.ApiKeys = []string{"legacy-secret"}

		got := observed{}
		app := newProbeApp(db, appConfig, &got)
		rec := doRequest(app, http.MethodGet, "/probe", withBearerToken("legacy-secret"))

		Expect(rec.Code).To(Equal(http.StatusOK))
		Expect(got.source).To(Equal(auth.UsageSourceLegacy))
		Expect(got.key).To(BeNil())
	})
})
|
||||
})
|
||||
|
||||
@@ -5,14 +5,31 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Source classification for a UsageRecord: how the request that produced the
// record was authenticated.
const (
	// UsageSourceAPIKey marks a request authenticated with a named UserAPIKey.
	UsageSourceAPIKey = "apikey"
	// UsageSourceWeb marks a request authenticated with a session cookie (web UI).
	UsageSourceWeb = "web"
	// UsageSourceLegacy marks a request authenticated with an env-configured legacy key.
	UsageSourceLegacy = "legacy"
)
|
||||
|
||||
// UsageRecord represents a single API request's token usage.
|
||||
type UsageRecord struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
UserName string `gorm:"size:255"`
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
UserName string `gorm:"size:255"`
|
||||
|
||||
// Source classifies how the request authenticated. One of UsageSource* constants.
|
||||
// Empty for pre-feature rows until the InitDB backfill runs.
|
||||
Source string `gorm:"size:16;index:idx_usage_source"`
|
||||
// APIKeyID is the UserAPIKey.ID when Source == UsageSourceAPIKey. Nil otherwise.
|
||||
APIKeyID *string `gorm:"size:36;index:idx_usage_apikey"`
|
||||
// APIKeyName is a snapshot of UserAPIKey.Name at write time. Survives key deletion.
|
||||
APIKeyName string `gorm:"size:255"`
|
||||
|
||||
Model string `gorm:"size:255;index"`
|
||||
Endpoint string `gorm:"size:255"`
|
||||
PromptTokens int64
|
||||
@@ -30,9 +47,12 @@ func RecordUsage(db *gorm.DB, record *UsageRecord) error {
|
||||
// UsageBucket is an aggregated time bucket for the dashboard.
|
||||
type UsageBucket struct {
|
||||
Bucket string `json:"bucket"`
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
UserName string `json:"user_name,omitempty"`
|
||||
Source string `json:"source,omitempty"`
|
||||
APIKeyID string `json:"api_key_id,omitempty"`
|
||||
APIKeyName string `json:"api_key_name,omitempty"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
@@ -119,6 +139,28 @@ func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) {
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
// BackfillUsageSource sets the Source column on pre-feature usage rows.
|
||||
// Idempotent: only touches rows where source is NULL or empty.
|
||||
// - rows whose user_id == "legacy-api-key" -> UsageSourceLegacy
|
||||
// - everything else -> UsageSourceWeb
|
||||
func BackfillUsageSource(db *gorm.DB) error {
|
||||
// Legacy first (more specific predicate)
|
||||
if err := db.Exec(
|
||||
`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '') AND user_id = ?`,
|
||||
UsageSourceLegacy, "legacy-api-key",
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("backfill legacy usage source: %w", err)
|
||||
}
|
||||
// Everything else -> web
|
||||
if err := db.Exec(
|
||||
`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '')`,
|
||||
UsageSourceWeb,
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("backfill web usage source: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter.
|
||||
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
@@ -149,3 +191,257 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||
}
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
// TotalsEntry is a token+request roll-up: one pair of counters for a source
// or for the grand total.
type TotalsEntry struct {
	// Tokens is the summed total_tokens of the rows rolled up.
	Tokens int64 `json:"tokens"`
	// Requests is the number of usage rows rolled up.
	Requests int64 `json:"requests"`
}
|
||||
|
||||
// KeyTotal is the per-API-key roll-up returned by the sources endpoints.
// The owner fields (UserID, UserName) are read from the usage rows themselves,
// so keys that were revoked and deleted still carry their owner attribution in
// admin views.
type KeyTotal struct {
	APIKeyID   string `json:"api_key_id"`
	APIKeyName string `json:"api_key_name"`
	UserID     string `json:"user_id"`
	UserName   string `json:"user_name"`
	// Tokens is the summed total_tokens recorded against the key.
	Tokens int64 `json:"tokens"`
	// Requests is the number of usage rows recorded against the key.
	Requests int64 `json:"requests"`
	// LastUsed is the latest created_at among those rows.
	LastUsed time.Time `json:"last_used"`
}
|
||||
|
||||
// UserSourceTotal is a roll-up per (user, source) pair for the sources that
// have no named API key identity (web, legacy). Admin views use it to show
// which user produced each block of Web UI / legacy traffic; traffic with
// source=apikey is broken down per key in KeyTotal instead.
type UserSourceTotal struct {
	Source   string `json:"source"`
	UserID   string `json:"user_id"`
	UserName string `json:"user_name"`
	// Tokens is the summed total_tokens for the pair.
	Tokens int64 `json:"tokens"`
	// Requests is the number of usage rows for the pair.
	Requests int64 `json:"requests"`
}
|
||||
|
||||
// SourceTotals summarises a per-source breakdown.
|
||||
type SourceTotals struct {
|
||||
BySource map[string]TotalsEntry `json:"by_source"`
|
||||
ByKey []KeyTotal `json:"by_key"` // server-sorted desc by tokens, capped
|
||||
ByUserSource []UserSourceTotal `json:"by_user_source,omitempty"` // populated only when includeLegacy=true
|
||||
GrandTotal TotalsEntry `json:"grand_total"`
|
||||
}
|
||||
|
||||
// maxKeyTotals is the row limit applied to the per-key roll-up query.
const maxKeyTotals = 200
|
||||
|
||||
// GetUserUsageBySource returns per-source aggregated usage for one user. Legacy
|
||||
// is excluded by design (visible to admins only via the admin variant).
|
||||
func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
since, dateFmt := periodToWindow(period, sqlite)
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
"COUNT(*) as request_count").
|
||||
Where("user_id = ?", userID).
|
||||
Where("source <> ?", UsageSourceLegacy).
|
||||
Group("bucket, source, api_key_id, api_key_name").
|
||||
Order("bucket ASC")
|
||||
|
||||
if !since.IsZero() {
|
||||
query = query.Where("created_at >= ?", since)
|
||||
}
|
||||
|
||||
var buckets []UsageBucket
|
||||
if err := query.Find(&buckets).Error; err != nil {
|
||||
return nil, SourceTotals{}, err
|
||||
}
|
||||
|
||||
totals := computeSourceTotals(db, userID, "", since, false)
|
||||
return buckets, totals, nil
|
||||
}
|
||||
|
||||
// computeSourceTotals rolls up by_source / by_key / grand_total.
|
||||
// userID/apiKeyID are optional filters. includeLegacy controls whether the
|
||||
// legacy bucket is exposed (admin-only).
|
||||
func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals {
|
||||
totals := SourceTotals{BySource: map[string]TotalsEntry{}}
|
||||
|
||||
bySourceQ := db.Model(&UsageRecord{}).
|
||||
Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests").
|
||||
Group("source")
|
||||
bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy)
|
||||
|
||||
var bySourceRows []struct {
|
||||
Source string
|
||||
Tokens int64
|
||||
Requests int64
|
||||
}
|
||||
if err := bySourceQ.Scan(&bySourceRows).Error; err != nil {
|
||||
xlog.Warn("computeSourceTotals: by-source Scan failed", "error", err)
|
||||
return totals
|
||||
}
|
||||
for _, r := range bySourceRows {
|
||||
totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests}
|
||||
totals.GrandTotal.Tokens += r.Tokens
|
||||
totals.GrandTotal.Requests += r.Requests
|
||||
}
|
||||
|
||||
byKeyQ := db.Model(&UsageRecord{}).
|
||||
Select("COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||
"user_id, user_name, "+
|
||||
"SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used").
|
||||
Where("api_key_id IS NOT NULL AND api_key_id <> ''").
|
||||
Group("api_key_id, api_key_name, user_id, user_name").
|
||||
Order("tokens DESC").
|
||||
Limit(maxKeyTotals)
|
||||
byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy)
|
||||
|
||||
// Iterate Rows() manually because MAX(created_at) is returned as a string by
|
||||
// the SQLite driver, and Go's database/sql refuses to scan that into
|
||||
// *time.Time. Postgres returns a proper timestamp. We accept both shapes
|
||||
// via a Rows.Scan into a string column, then parse uniformly.
|
||||
rows, err := byKeyQ.Rows()
|
||||
if err != nil {
|
||||
xlog.Warn("computeSourceTotals: by-key Rows() failed", "error", err)
|
||||
} else {
|
||||
defer func() { _ = rows.Close() }()
|
||||
out := make([]KeyTotal, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
apiKeyID, apiKeyName, userIDCol, userName, lastUsedRaw string
|
||||
tokens, requests int64
|
||||
)
|
||||
if scanErr := rows.Scan(&apiKeyID, &apiKeyName, &userIDCol, &userName, &tokens, &requests, &lastUsedRaw); scanErr != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, KeyTotal{
|
||||
APIKeyID: apiKeyID,
|
||||
APIKeyName: apiKeyName,
|
||||
UserID: userIDCol,
|
||||
UserName: userName,
|
||||
Tokens: tokens,
|
||||
Requests: requests,
|
||||
LastUsed: parseLastUsedString(lastUsedRaw),
|
||||
})
|
||||
}
|
||||
if rerr := rows.Err(); rerr != nil {
|
||||
xlog.Warn("computeSourceTotals: by-key rows iteration failed", "error", rerr)
|
||||
}
|
||||
totals.ByKey = out
|
||||
}
|
||||
|
||||
// by_user_source: only populated for admin callers (includeLegacy=true) so
|
||||
// they can attribute Web UI / legacy traffic to specific users. Per-apikey
|
||||
// rows already carry user info via KeyTotal above, so this query only
|
||||
// covers source != apikey.
|
||||
if includeLegacy {
|
||||
byUserSourceQ := db.Model(&UsageRecord{}).
|
||||
Select("source, user_id, user_name, "+
|
||||
"SUM(total_tokens) as tokens, COUNT(*) as requests").
|
||||
Where("source <> ?", UsageSourceAPIKey).
|
||||
Group("source, user_id, user_name").
|
||||
Order("tokens DESC")
|
||||
byUserSourceQ = applyFilters(byUserSourceQ, userID, apiKeyID, since, includeLegacy)
|
||||
|
||||
var byUserSourceRows []UserSourceTotal
|
||||
if scanErr := byUserSourceQ.Scan(&byUserSourceRows).Error; scanErr != nil {
|
||||
xlog.Warn("computeSourceTotals: by-user-source Scan failed", "error", scanErr)
|
||||
} else {
|
||||
totals.ByUserSource = byUserSourceRows
|
||||
}
|
||||
}
|
||||
|
||||
return totals
|
||||
}
|
||||
|
||||
// parseLastUsedString converts the textual MAX(created_at) value returned by
|
||||
// SQLite (or any driver that surfaces the timestamp as a string) into a
|
||||
// time.Time. Returns the zero time on parse failure.
|
||||
func parseLastUsedString(s string) time.Time {
|
||||
if s == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
// GORM's SQLite driver emits Go's default time formatting. Try the formats
|
||||
// it commonly produces, falling back to RFC3339Nano.
|
||||
layouts := []string{
|
||||
"2006-01-02 15:04:05.999999999 -0700 MST",
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if t, err := time.Parse(layout, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
xlog.Warn("parseLastUsedString: unrecognised format", "value", s)
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// GetAllUsageBySource is the admin variant of GetUserUsageBySource.
|
||||
// Optional filters: userID and apiKeyID. Legacy is included.
|
||||
// truncated == true iff the per-key roll-up was capped at maxKeyTotals.
|
||||
func GetAllUsageBySource(db *gorm.DB, period, userID, apiKeyID string) ([]UsageBucket, SourceTotals, bool, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
since, dateFmt := periodToWindow(period, sqlite)
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||
"user_id, user_name, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
"COUNT(*) as request_count").
|
||||
Group("bucket, source, api_key_id, api_key_name, user_id, user_name").
|
||||
Order("bucket ASC")
|
||||
|
||||
query = applyFilters(query, userID, apiKeyID, since, true)
|
||||
|
||||
var buckets []UsageBucket
|
||||
if err := query.Find(&buckets).Error; err != nil {
|
||||
return nil, SourceTotals{}, false, err
|
||||
}
|
||||
|
||||
totals := computeSourceTotals(db, userID, apiKeyID, since, true)
|
||||
|
||||
// Count distinct api_key_ids matching the filters. If > maxKeyTotals,
|
||||
// the by_key slice was capped and we signal truncation to the caller.
|
||||
truncated := false
|
||||
var distinct int64
|
||||
countQ := applyFilters(
|
||||
db.Model(&UsageRecord{}).
|
||||
Distinct("api_key_id").
|
||||
Where("api_key_id IS NOT NULL AND api_key_id <> ''"),
|
||||
userID, apiKeyID, since, true,
|
||||
)
|
||||
if err := countQ.Count(&distinct).Error; err != nil {
|
||||
xlog.Warn("GetAllUsageBySource: distinct api_key_id count failed", "error", err)
|
||||
} else {
|
||||
truncated = distinct > maxKeyTotals
|
||||
}
|
||||
|
||||
return buckets, totals, truncated, nil
|
||||
}
|
||||
|
||||
func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB {
|
||||
if userID != "" {
|
||||
q = q.Where("user_id = ?", userID)
|
||||
}
|
||||
if apiKeyID != "" {
|
||||
q = q.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
if !since.IsZero() {
|
||||
q = q.Where("created_at >= ?", since)
|
||||
}
|
||||
if !includeLegacy {
|
||||
q = q.Where("source <> ?", UsageSourceLegacy)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("Usage", func() {
|
||||
@@ -158,4 +160,275 @@ var _ = Describe("Usage", func() {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Usage source backfill", func() {
	// insertUnsourced writes a usage row for userID with an empty source
	// column, using raw SQL instead of auth.RecordUsage.
	insertUnsourced := func(db *gorm.DB, userID string) {
		sqlDB, err := db.DB()
		Expect(err).ToNot(HaveOccurred())
		_, err = sqlDB.Exec(
			`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
			userID, "gpt-4", time.Now())
		Expect(err).ToNot(HaveOccurred())
	}

	// backfillAndLoad runs the backfill and returns the first stored row
	// belonging to userID.
	backfillAndLoad := func(db *gorm.DB, userID string) auth.UsageRecord {
		Expect(auth.BackfillUsageSource(db)).To(Succeed())

		var row auth.UsageRecord
		Expect(db.Where("user_id = ?", userID).First(&row).Error).To(Succeed())
		return row
	}

	It("backfills 'web' for pre-feature rows", func() {
		db := testDB()
		insertUnsourced(db, "user-x")

		row := backfillAndLoad(db, "user-x")
		Expect(row.Source).To(Equal(auth.UsageSourceWeb))
	})

	It("backfills 'legacy' for pre-feature rows with legacy-api-key user_id", func() {
		db := testDB()
		insertUnsourced(db, "legacy-api-key")

		row := backfillAndLoad(db, "legacy-api-key")
		Expect(row.Source).To(Equal(auth.UsageSourceLegacy))
	})

	It("is idempotent on re-run", func() {
		db := testDB()
		// Two consecutive runs on the same database must both succeed.
		for run := 0; run < 2; run++ {
			Expect(auth.BackfillUsageSource(db)).To(Succeed())
		}
	})
})
|
||||
|
||||
Describe("UsageRecord with source fields", func() {
	// storeAndReload records rec and returns the copy read back by primary key.
	storeAndReload := func(db *gorm.DB, rec *auth.UsageRecord) auth.UsageRecord {
		Expect(auth.RecordUsage(db, rec)).To(Succeed())

		var stored auth.UsageRecord
		Expect(db.First(&stored, rec.ID).Error).To(Succeed())
		return stored
	}

	It("persists Source, APIKeyID, APIKeyName", func() {
		db := testDB()
		keyID := "key-uuid-1"

		stored := storeAndReload(db, &auth.UsageRecord{
			UserID:      "user-1",
			UserName:    "Test User",
			Source:      auth.UsageSourceAPIKey,
			APIKeyID:    &keyID,
			APIKeyName:  "ci-runner",
			Model:       "gpt-4",
			Endpoint:    "/v1/chat/completions",
			TotalTokens: 150,
			CreatedAt:   time.Now(),
		})

		Expect(stored.Source).To(Equal(auth.UsageSourceAPIKey))
		// Check the pointer before dereferencing it.
		Expect(stored.APIKeyID).ToNot(BeNil())
		Expect(*stored.APIKeyID).To(Equal("key-uuid-1"))
		Expect(stored.APIKeyName).To(Equal("ci-runner"))
	})

	It("allows nil APIKeyID for web/legacy sources", func() {
		db := testDB()

		stored := storeAndReload(db, &auth.UsageRecord{
			UserID:    "user-1",
			Source:    auth.UsageSourceWeb,
			Model:     "gpt-4",
			CreatedAt: time.Now(),
		})

		Expect(stored.Source).To(Equal(auth.UsageSourceWeb))
		Expect(stored.APIKeyID).To(BeNil())
		Expect(stored.APIKeyName).To(BeEmpty())
	})
})
|
||||
|
||||
Describe("GetUserUsageBySource", func() {
	// insert records one usage row. The api key snapshot (id + name) is
	// attached only when keyID is non-empty.
	insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) {
		rec := &auth.UsageRecord{
			UserID:      userID,
			Source:      source,
			Model:       "gpt-4",
			TotalTokens: tokens,
			CreatedAt:   when,
		}
		if keyID != "" {
			rec.APIKeyID = &keyID
			rec.APIKeyName = keyName
		}
		Expect(auth.RecordUsage(db, rec)).To(Succeed())
	}

	It("returns only the caller's rows, never legacy", func() {
		db := testDB()
		now := time.Now()
		insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now)
		insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now)
		insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now)
		insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now)

		buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
		Expect(err).ToNot(HaveOccurred())

		// Guard against a vacuous pass: without this, an empty result would
		// skip every per-bucket assertion in the loop below.
		Expect(buckets).ToNot(BeEmpty())
		for _, b := range buckets {
			Expect(b.UserID).To(Or(BeEmpty(), Equal("alice")))
			Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy))
		}

		// 100 (api key) + 50 (web); alice's legacy row and bob's row are excluded.
		Expect(totals.GrandTotal.Tokens).To(Equal(int64(150)))
		Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100)))
		Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50)))
		_, hasLegacy := totals.BySource[auth.UsageSourceLegacy]
		Expect(hasLegacy).To(BeFalse())
	})

	It("snapshots survive key deletion", func() {
		db := testDB()
		now := time.Now()
		// No api key row is created: only the snapshot stored on the usage
		// record exists, as would be the case after the key is deleted.
		insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now)

		_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
		Expect(err).ToNot(HaveOccurred())
		Expect(totals.ByKey).To(HaveLen(1))
		Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name"))
		Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key"))
		Expect(totals.ByKey[0].LastUsed).ToNot(BeZero())
		Expect(totals.ByKey[0].LastUsed).To(BeTemporally("~", now, 2*time.Second))
	})
})
|
||||
|
||||
Describe("GetAllUsageBySource", func() {
	// insertNamed records a row with explicit user_id, user_name, source,
	// and an optional api key snapshot (attached only when keyID is
	// non-empty). Used directly by the user-attribution specs below.
	insertNamed := func(db *gorm.DB, userID, userName, source, keyID, keyName string, tokens int64) {
		rec := &auth.UsageRecord{
			UserID:      userID,
			UserName:    userName,
			Source:      source,
			Model:       "gpt-4",
			TotalTokens: tokens,
			CreatedAt:   time.Now(),
		}
		if keyID != "" {
			rec.APIKeyID = &keyID
			rec.APIKeyName = keyName
		}
		Expect(auth.RecordUsage(db, rec)).To(Succeed())
	}

	// insert is the short form for specs that do not care about names: no
	// user name, and the key name is derived as "name-" + keyID.
	insert := func(db *gorm.DB, userID, source, keyID string, tokens int64) {
		insertNamed(db, userID, "", source, keyID, "name-"+keyID, tokens)
	}

	It("includes legacy for admins", func() {
		db := testDB()
		insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
		insert(db, "legacy-api-key", auth.UsageSourceLegacy, "", 5)

		_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
		Expect(err).ToNot(HaveOccurred())
		Expect(totals.BySource).To(HaveKey(auth.UsageSourceLegacy))
		Expect(totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(5)))
	})

	It("filters by user_id AND api_key_id", func() {
		db := testDB()
		insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
		insert(db, "alice", auth.UsageSourceAPIKey, "k2", 20)
		insert(db, "bob", auth.UsageSourceAPIKey, "k3", 30)

		_, totals, _, err := auth.GetAllUsageBySource(db, "month", "alice", "k2")
		Expect(err).ToNot(HaveOccurred())
		Expect(totals.GrandTotal.Tokens).To(Equal(int64(20)))
	})

	It("sets truncated=true when by_key exceeds the cap", func() {
		db := testDB()
		// 210 distinct keys with strictly decreasing token counts.
		for i := 0; i < 210; i++ {
			insert(db, "alice", auth.UsageSourceAPIKey, fmt.Sprintf("key-%03d", i), int64(210-i))
		}

		_, totals, truncated, err := auth.GetAllUsageBySource(db, "month", "", "")
		Expect(err).ToNot(HaveOccurred())
		Expect(truncated).To(BeTrue())
		Expect(totals.ByKey).To(HaveLen(200))
		// BeNumerically reports both values on failure, unlike a bare
		// boolean comparison wrapped in BeTrue().
		Expect(totals.ByKey[0].Tokens).To(BeNumerically(">", totals.ByKey[199].Tokens))
	})

	It("attributes each KeyTotal to its owner user", func() {
		db := testDB()
		insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 100)
		insertNamed(db, "bob", "Bob", auth.UsageSourceAPIKey, "k2", "lap", 50)

		_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
		Expect(err).ToNot(HaveOccurred())
		Expect(totals.ByKey).To(HaveLen(2))

		// Index by key id so the assertions do not depend on slice order.
		byID := map[string]auth.KeyTotal{}
		for _, k := range totals.ByKey {
			byID[k.APIKeyID] = k
		}
		Expect(byID["k1"].UserID).To(Equal("alice"))
		Expect(byID["k1"].UserName).To(Equal("Alice"))
		Expect(byID["k2"].UserID).To(Equal("bob"))
		Expect(byID["k2"].UserName).To(Equal("Bob"))
	})

	It("breaks Web UI and legacy traffic out per user in by_user_source for admin", func() {
		db := testDB()
		// Alice and Bob both have Web UI traffic; a synthetic legacy user
		// also contributes. ByUserSource should expose one row per
		// (source, user) pair, never for source=apikey.
		insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
		insertNamed(db, "bob", "Bob", auth.UsageSourceWeb, "", "", 70)
		insertNamed(db, "legacy-api-key", "API Key User", auth.UsageSourceLegacy, "", "", 10)
		insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 5)

		_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
		Expect(err).ToNot(HaveOccurred())
		Expect(totals.ByUserSource).ToNot(BeEmpty())

		for _, r := range totals.ByUserSource {
			Expect(r.Source).ToNot(Equal(auth.UsageSourceAPIKey))
		}

		webByUser := map[string]int64{}
		legacyByUser := map[string]int64{}
		for _, r := range totals.ByUserSource {
			switch r.Source {
			case auth.UsageSourceWeb:
				webByUser[r.UserID] = r.Tokens
			case auth.UsageSourceLegacy:
				legacyByUser[r.UserID] = r.Tokens
			}
		}
		Expect(webByUser["alice"]).To(Equal(int64(30)))
		Expect(webByUser["bob"]).To(Equal(int64(70)))
		Expect(legacyByUser["legacy-api-key"]).To(Equal(int64(10)))
	})

	It("does NOT populate by_user_source in the non-admin path", func() {
		db := testDB()
		insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)

		_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
		Expect(err).ToNot(HaveOccurred())
		// Non-admin path uses includeLegacy=false, so by_user_source stays nil.
		Expect(totals.ByUserSource).To(BeNil())
	})
})
|
||||
})
|
||||
|
||||
@@ -16,8 +16,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
@@ -381,14 +384,24 @@ func ResumeNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// InstallBackendOnNodeEndpoint triggers backend installation on a worker node via NATS.
|
||||
// InstallBackendOnNodeEndpoint triggers backend installation on a worker node.
|
||||
// Async: enqueues a ManagementOp on the gallery service channel and returns a
|
||||
// jobID immediately. The gallery service worker goroutine drives the actual
|
||||
// install via DistributedBackendManager.InstallBackend, which honors the op's
|
||||
// TargetNodeID to scope the fan-out to one node. The UI polls /api/backends/job/:uid
|
||||
// for progress, mirroring /api/backends/install/:id.
|
||||
//
|
||||
// Backend can be either a gallery ID (resolved against BackendGalleries) or a
|
||||
// direct URI install (URI + Name + optional Alias) — same shape as the
|
||||
// direct URI install (URI + Name + optional Alias) - same shape as the
|
||||
// standalone /api/backends/install-external path, just scoped to one node.
|
||||
func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
|
||||
//
|
||||
// The legacy unloader argument is retained for signature symmetry with
|
||||
// DeleteBackendOnNodeEndpoint / ListBackendsOnNodeEndpoint but is no longer
|
||||
// used here - the async path goes through galleryService.
|
||||
func InstallBackendOnNodeEndpoint(_ nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if unloader == nil {
|
||||
return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
|
||||
if galleryService == nil {
|
||||
return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "gallery service not configured"))
|
||||
}
|
||||
nodeID := c.Param("id")
|
||||
var req struct {
|
||||
@@ -401,25 +414,65 @@ func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.Handler
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
// Either a gallery backend name or a direct URI must be supplied.
|
||||
if req.Backend == "" && req.URI == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name or uri required"))
|
||||
}
|
||||
// Admin-driven backend install: not tied to a specific replica slot
|
||||
// (no model is being loaded). Pass replica 0 to match the worker's
|
||||
// admin process-key convention (`backend#0`). The worker's fast path
|
||||
// takes over if the backend is already running — upgrades go through
|
||||
// the dedicated /api/backends/upgrade path on backend.upgrade.
|
||||
reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries, req.URI, req.Name, req.Alias, 0)
|
||||
|
||||
jobUUID, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
xlog.Error("Failed to install backend on node", "node", nodeID, "backend", req.Backend, "uri", req.URI, "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to install backend on node"))
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to generate job id"))
|
||||
}
|
||||
if !reply.Success {
|
||||
xlog.Error("Backend install failed on node", "node", nodeID, "backend", req.Backend, "uri", req.URI, "error", reply.Error)
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "backend installation failed"))
|
||||
jobID := jobUUID.String()
|
||||
|
||||
// Cache key: for gallery installs, use the backend slug; for URI
|
||||
// installs prefer the provided Name (falling back to URI). All keys
|
||||
// are node-scoped so concurrent installs of the same backend on
|
||||
// different nodes do not stomp each other in opcache.
|
||||
backendKey := req.Backend
|
||||
if backendKey == "" {
|
||||
backendKey = req.Name
|
||||
if backendKey == "" {
|
||||
backendKey = req.URI
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"message": "backend installed"})
|
||||
cacheKey := galleryop.NodeScopedKey(nodeID, backendKey)
|
||||
opcache.SetBackend(cacheKey, jobID)
|
||||
|
||||
// Optional caller-supplied galleries override. Mirrors the standalone
|
||||
// install path so an admin can point at a private gallery.
|
||||
galleries := appConfig.BackendGalleries
|
||||
if req.BackendGalleries != "" {
|
||||
var custom []config.Gallery
|
||||
if err := json.Unmarshal([]byte(req.BackendGalleries), &custom); err != nil {
|
||||
xlog.Warn("Ignoring malformed backend_galleries override; falling back to configured galleries", "error", err, "nodeID", nodeID)
|
||||
} else if len(custom) > 0 {
|
||||
galleries = custom
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
ID: jobID,
|
||||
GalleryElementName: req.Backend,
|
||||
Galleries: galleries,
|
||||
TargetNodeID: nodeID,
|
||||
ExternalURI: req.URI,
|
||||
ExternalName: req.Name,
|
||||
ExternalAlias: req.Alias,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
galleryService.StoreCancellation(jobID, cancelFunc)
|
||||
go func() {
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
|
||||
xlog.Info("Node-scoped backend install dispatched", "node", nodeID, "backend", req.Backend, "uri", req.URI, "jobID", jobID)
|
||||
return c.JSON(http.StatusAccepted, map[string]string{
|
||||
"jobID": jobID,
|
||||
"statusUrl": "/api/backends/job/" + jobID,
|
||||
"message": "backend installation started",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
123
core/http/endpoints/localai/nodes_install_async_test.go
Normal file
123
core/http/endpoints/localai/nodes_install_async_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
)
|
||||
|
||||
// InstallBackendOnNodeEndpoint became async to stop blocking the browser on
|
||||
// the 3-minute NATS reply timeout. These specs lock in the new contract:
|
||||
// HTTP 202 with a jobID, a ManagementOp enqueued on the gallery channel, and
|
||||
// an opcache entry keyed by NodeScopedKey so concurrent installs of the same
|
||||
// backend on different nodes do not stomp each other.
|
||||
var _ = Describe("InstallBackendOnNodeEndpoint async behavior", func() {
|
||||
var (
|
||||
e *echo.Echo
|
||||
galleryService *galleryop.GalleryService
|
||||
opcache *galleryop.OpCache
|
||||
appCfg *config.ApplicationConfig
|
||||
dispatched chan galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
done chan struct{}
|
||||
drainExited chan struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
appCfg = &config.ApplicationConfig{
|
||||
BackendGalleries: []config.Gallery{{Name: "test-gallery", URL: "http://example.com"}},
|
||||
}
|
||||
galleryService = galleryop.NewGalleryService(appCfg, nil)
|
||||
opcache = galleryop.NewOpCache(galleryService)
|
||||
// Drain the gallery channel into a buffered side channel so the
|
||||
// handler's `go func() { ch <- op }()` send does not block waiting
|
||||
// for the real worker (which is not running in this unit test).
|
||||
dispatched = make(chan galleryop.ManagementOp[gallery.GalleryBackend, any], 4)
|
||||
done = make(chan struct{})
|
||||
drainExited = make(chan struct{})
|
||||
go func() {
|
||||
defer close(drainExited)
|
||||
for {
|
||||
select {
|
||||
case op := <-galleryService.BackendGalleryChannel:
|
||||
dispatched <- op
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Signal the drain goroutine to exit. We do NOT close
|
||||
// BackendGalleryChannel: the handler's dispatch goroutine may still
|
||||
// be pending (specs that don't Eventually-Receive), and a send on a
|
||||
// closed channel panics. Signalling via `done` lets the drain
|
||||
// goroutine return without touching the gallery channel.
|
||||
close(done)
|
||||
Eventually(drainExited, "2s").Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns 202 with a jobID and dispatches a TargetNodeID-scoped op", func() {
|
||||
body := `{"backend": "llama-cpp"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/nodes/node-xyz/backends/install", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues("node-xyz")
|
||||
|
||||
handler := localai.InstallBackendOnNodeEndpoint(nil, galleryService, opcache, appCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusAccepted))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["jobID"]).To(BeAssignableToTypeOf(""))
|
||||
Expect(resp["jobID"].(string)).ToNot(BeEmpty())
|
||||
Expect(resp["message"]).To(Equal("backend installation started"))
|
||||
|
||||
Eventually(dispatched, "2s").Should(Receive())
|
||||
Expect(opcache.Exists(galleryop.NodeScopedKey("node-xyz", "llama-cpp"))).To(BeTrue())
|
||||
Expect(opcache.IsBackendOp(galleryop.NodeScopedKey("node-xyz", "llama-cpp"))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns 400 when neither backend nor uri is supplied", func() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/nodes/node-xyz/backends/install", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues("node-xyz")
|
||||
|
||||
handler := localai.InstallBackendOnNodeEndpoint(nil, galleryService, opcache, appCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
It("accepts a direct URI install and uses the name as the cache key", func() {
|
||||
body := `{"uri": "oci://example.com/custom-backend:v1", "name": "custom"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/nodes/node-xyz/backends/install", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues("node-xyz")
|
||||
|
||||
handler := localai.InstallBackendOnNodeEndpoint(nil, galleryService, opcache, appCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusAccepted))
|
||||
|
||||
Expect(opcache.Exists(galleryop.NodeScopedKey("node-xyz", "custom"))).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -73,363 +73,6 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int) error {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
// Always keep the Go-side extractor in sync with raw tokens so it
|
||||
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
// When C++ autoparser chat deltas are available, prefer them — they
|
||||
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
||||
// Otherwise fall back to Go-side extraction.
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
}
|
||||
if reasoningDelta != "" {
|
||||
delta.Reasoning = &reasoningDelta
|
||||
}
|
||||
|
||||
// Usage rides as a struct field for the consumer to track the
|
||||
// running cumulative — it is stripped before JSON marshal so the
|
||||
// wire chunk stays spec-compliant (no `usage` on intermediate
|
||||
// chunks). The dedicated trailer chunk (when include_usage=true)
|
||||
// carries the final totals.
|
||||
usageForChunk := usage
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: &usageForChunk,
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
return err
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int, textContentToReturn *string) error {
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
sentReasoning := false
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
_, _, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
// Track whether ChatDeltas from the C++ autoparser contain
|
||||
// tool calls or content, so the retry decision can account for them.
|
||||
for _, d := range usage.ChatDeltas {
|
||||
if len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaToolCalls = true
|
||||
}
|
||||
if d.Content != "" {
|
||||
hasChatDeltaContent = true
|
||||
}
|
||||
}
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
if usage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
if reasoningDelta != "" {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentReasoning = true
|
||||
}
|
||||
|
||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||
// have been detected. Once the incremental parser finds tool calls,
|
||||
// content stops — per OpenAI spec, content and tool_calls don't mix.
|
||||
if lastEmittedCount == 0 && contentDelta != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &contentDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
// Try incremental XML parsing for streaming support using iterative parser
|
||||
// This allows emitting partial tool calls as they're being generated
|
||||
cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
|
||||
// Determine XML format from config
|
||||
var xmlFormat *functions.XMLToolCallFormat
|
||||
if config.FunctionsConfig.XMLFormat != nil {
|
||||
xmlFormat = config.FunctionsConfig.XMLFormat
|
||||
} else if config.FunctionsConfig.XMLFormatPreset != "" {
|
||||
xmlFormat = functions.GetXMLFormatPreset(config.FunctionsConfig.XMLFormatPreset)
|
||||
}
|
||||
|
||||
// Use iterative parser for streaming (partial parsing enabled)
|
||||
// Try XML parsing first
|
||||
partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
|
||||
if parseErr == nil && len(partialResults) > 0 {
|
||||
// Emit new XML tool calls that weren't emitted before
|
||||
if len(partialResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(partialResults); i++ {
|
||||
toolCall := partialResults[i]
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: toolCall.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
select {
|
||||
case responses <- initialMessage:
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastEmittedCount = len(partialResults)
|
||||
}
|
||||
} else {
|
||||
// Try JSON tool call parsing for streaming.
|
||||
// Only emit NEW tool calls (same guard as XML parser above).
|
||||
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
||||
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
||||
jsonObj := jsonResults[i]
|
||||
name, ok := jsonObj["name"].(string)
|
||||
if !ok || name == "" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||
if argsStr, ok := argsVal.(string); ok {
|
||||
args = argsStr
|
||||
} else {
|
||||
argsBytes, _ := json.Marshal(argsVal)
|
||||
args = string(argsBytes)
|
||||
}
|
||||
}
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
}
|
||||
lastEmittedCount = len(jsonResults)
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
func(attempt int) bool {
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry).
|
||||
// Also check ChatDelta flags — when the C++ autoparser is active,
|
||||
// tool calls and content are delivered via ChatDeltas while the
|
||||
// raw message is cleared. Without this check, we'd retry
|
||||
// unnecessarily, losing valid results and concatenating output.
|
||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||
hasContent := cleaned != "" || hasChatDeltaContent
|
||||
if !hasContent && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
hasChatDeltaToolCalls = false
|
||||
hasChatDeltaContent = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
||||
var functionResults []functions.FuncCallResults
|
||||
var reasoning string
|
||||
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
||||
functionResults = deltaToolCalls
|
||||
// Use content/reasoning from deltas too
|
||||
*textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning = extractor.Reasoning()
|
||||
cleanedResult := extractor.CleanedContent()
|
||||
*textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn)
|
||||
// noAction is a sentinel "just answer" pseudo-function — not a real
|
||||
// tool call. Scan the whole slice rather than only index 0 so we
|
||||
// don't drop a real tool call that happens to follow a noAction
|
||||
// entry, and so the default branch isn't entered with only noAction
|
||||
// entries to emit as tool_calls.
|
||||
noActionToRun := !hasRealCall(functionResults, noAction)
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
// Token-cumulative usage is communicated to the streaming
|
||||
// consumer via the per-token callback's chunk struct (stripped
|
||||
// before wire marshal). The final usage trailer — when the
|
||||
// caller opted in with stream_options.include_usage — is built
|
||||
// by the outer streaming loop, not here.
|
||||
var result string
|
||||
if !sentInitialRole {
|
||||
var hqErr error
|
||||
result, hqErr = handleQuestion(config, functionResults, extractor.CleanedContent(), prompt)
|
||||
if hqErr != nil {
|
||||
xlog.Error("error handling question", "error", hqErr)
|
||||
return hqErr
|
||||
}
|
||||
}
|
||||
for _, chunk := range buildNoActionFinalChunks(
|
||||
id, req.Model, created,
|
||||
sentInitialRole, sentReasoning,
|
||||
result, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
|
||||
default:
|
||||
for _, chunk := range buildDeferredToolCallChunks(
|
||||
id, req.Model, created,
|
||||
functionResults, lastEmittedCount,
|
||||
sentInitialRole, *textContentToReturn,
|
||||
sentReasoning, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
}
|
||||
|
||||
close(responses)
|
||||
return err
|
||||
}
|
||||
|
||||
return func(c echo.Context) error {
|
||||
var textContentToReturn string
|
||||
id := uuid.New().String()
|
||||
@@ -697,17 +340,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
ended := make(chan error, 1)
|
||||
ended := make(chan streamWorkerResult, 1)
|
||||
|
||||
go func() {
|
||||
if !shouldUseFn {
|
||||
ended <- process(predInput, input, config, ml, responses, extraUsage, id, created)
|
||||
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created)
|
||||
ended <- streamWorkerResult{usage: u, err: err}
|
||||
} else {
|
||||
ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage, id, created, &textContentToReturn)
|
||||
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn)
|
||||
ended <- streamWorkerResult{usage: u, err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
usage := &schema.OpenAIUsage{}
|
||||
var finalUsage backend.TokenUsage
|
||||
toolsCalled := false
|
||||
var collectedToolCalls []schema.ToolCall
|
||||
var collectedContent string
|
||||
@@ -725,13 +370,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
xlog.Debug("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
// Capture the running cumulative usage from this chunk
|
||||
// (when present) so the include_usage trailer can carry
|
||||
// the final totals. Usage is stripped before marshal
|
||||
// below so the wire chunk stays spec-compliant.
|
||||
if ev.Usage != nil {
|
||||
usage = ev.Usage
|
||||
}
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
// Collect and merge tool call deltas for MCP execution
|
||||
@@ -747,11 +385,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
collectedContent += *sp
|
||||
}
|
||||
}
|
||||
// OpenAI streaming spec: intermediate chunks must NOT
|
||||
// carry a `usage` field. Strip the tracking copy
|
||||
// before marshalling — usage is delivered via the
|
||||
// dedicated trailer chunk when include_usage=true.
|
||||
ev.Usage = nil
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
xlog.Debug("Failed to marshal response", "error", err)
|
||||
@@ -766,15 +399,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
case res := <-ended:
|
||||
if res.err == nil {
|
||||
finalUsage = res.usage
|
||||
break LOOP
|
||||
}
|
||||
xlog.Error("Stream ended with error", "error", err)
|
||||
xlog.Error("Stream ended with error", "error", res.err)
|
||||
|
||||
errorResp := schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: err.Error(),
|
||||
Message: res.err.Error(),
|
||||
Type: "server_error",
|
||||
Code: "server_error",
|
||||
},
|
||||
@@ -797,7 +431,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// still trying to send (e.g., after client disconnect). The goroutine
|
||||
// calls close(responses) when done, which terminates the drain.
|
||||
if input.Context.Err() != nil {
|
||||
go func() { for range responses {} }()
|
||||
go func() {
|
||||
for range responses {
|
||||
}
|
||||
}()
|
||||
<-ended
|
||||
}
|
||||
|
||||
@@ -921,8 +558,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// Trailing usage chunk per OpenAI spec: emit only when the
|
||||
// caller opted in via stream_options.include_usage. Shape:
|
||||
// {"choices":[],"usage":{...},"object":"chat.completion.chunk",...}
|
||||
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage && usage != nil {
|
||||
trailer := streamUsageTrailerJSON(id, input.Model, created, *usage)
|
||||
//
|
||||
// finalUsage is the authoritative TokenUsage returned by the
|
||||
// worker function (processStream / processStreamWithTools) via the `ended`
|
||||
// channel. The worker reads it from ComputeChoices' return
|
||||
// value, which is the cumulative count produced by the backend
|
||||
// over the whole prediction. Issue #9927 was caused by the
|
||||
// tools-path worker not surfacing this value at all.
|
||||
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage {
|
||||
trailerUsage := streamUsageFromTokenUsage(finalUsage, extraUsage)
|
||||
trailer := streamUsageTrailerJSON(id, input.Model, created, trailerUsage)
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,39 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
)
|
||||
|
||||
// streamWorkerResult is what the streaming workers (process / processTools)
|
||||
// hand back to the outer ChatEndpoint loop through the `ended` channel.
|
||||
// Threading the final TokenUsage here, instead of piggy-backing it on the
|
||||
// `responses` SSE channel, keeps the SSE channel single-purpose (wire chunks)
|
||||
// and gives the trailer emitter a plain Go value to read after LOOP exits.
|
||||
// Fix for issue #9927: the previous tools-path worker never surfaced the
|
||||
// cumulative token counts at all, so the include_usage trailer reported zeros.
|
||||
type streamWorkerResult struct {
|
||||
usage backend.TokenUsage
|
||||
err error
|
||||
}
|
||||
|
||||
// streamUsageFromTokenUsage converts the backend's cumulative TokenUsage into
|
||||
// the OpenAI-spec OpenAIUsage shape used on the wire. `extraUsage` controls
|
||||
// whether the non-standard timing fields are forwarded.
|
||||
func streamUsageFromTokenUsage(usage backend.TokenUsage, extraUsage bool) schema.OpenAIUsage {
|
||||
out := schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
out.TimingTokenGeneration = usage.TimingTokenGeneration
|
||||
out.TimingPromptProcessing = usage.TimingPromptProcessing
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// streamUsageTrailerJSON returns the bytes of the OpenAI-spec trailing usage
|
||||
// chunk emitted in streaming completions when the request opts in via
|
||||
// `stream_options.include_usage: true`. The shape is:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -152,6 +156,28 @@ var _ = Describe("streaming usage spec compliance", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("streamUsageFromTokenUsage", func() {
	It("converts backend TokenUsage to schema OpenAIUsage", func() {
		// The input carries non-zero timings on purpose: with zero timings
		// the BeZero checks below would pass even if the converter always
		// forwarded them, so the extraUsage=false contract would go untested.
		tu := backend.TokenUsage{
			Prompt: 18, Completion: 213,
			TimingPromptProcessing: 0.5,
			TimingTokenGeneration:  1.5,
		}
		u := streamUsageFromTokenUsage(tu, false)
		Expect(u.PromptTokens).To(Equal(18))
		Expect(u.CompletionTokens).To(Equal(213))
		Expect(u.TotalTokens).To(Equal(231))
		Expect(u.TimingTokenGeneration).To(BeZero())
		Expect(u.TimingPromptProcessing).To(BeZero())
	})
	It("includes timings when extraUsage is true", func() {
		tu := backend.TokenUsage{
			Prompt: 10, Completion: 20,
			TimingPromptProcessing: 0.5,
			TimingTokenGeneration:  1.5,
		}
		u := streamUsageFromTokenUsage(tu, true)
		Expect(u.TimingPromptProcessing).To(Equal(0.5))
		Expect(u.TimingTokenGeneration).To(Equal(1.5))
		// Opting into timings must not disturb the token counts.
		Expect(u.PromptTokens).To(Equal(10))
		Expect(u.CompletionTokens).To(Equal(20))
		Expect(u.TotalTokens).To(Equal(30))
	})
})
|
||||
|
||||
Describe("OpenAIRequest.StreamOptions", func() {
|
||||
It("parses stream_options.include_usage=true", func() {
|
||||
body := []byte(`{
|
||||
@@ -177,3 +203,160 @@ var _ = Describe("streaming usage spec compliance", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Functional regression coverage for issue #9927: the streaming workers
|
||||
// must surface the cumulative TokenUsage returned by ComputeChoices to
|
||||
// their caller. The earlier broken implementations discarded that value
|
||||
// (`_, _, chatDeltas, err := ComputeChoices(...)`) and threw away the
|
||||
// counts on the floor, so the include_usage trailer always reported
|
||||
// zeros when tools were enabled.
|
||||
//
|
||||
// These tests stub backend.ModelInferenceFunc so the worker exercises the
|
||||
// real ComputeChoices → predFunc → LLMResponse pipeline. If a future change
|
||||
// drops the TokenUsage somewhere along that path, the assertions on the
|
||||
// returned value fail with a concrete count mismatch (e.g. 0 vs 213),
|
||||
// not with a "function undefined" compile error.
|
||||
var _ = Describe("streaming workers surface final TokenUsage (issue #9927)", func() {
|
||||
var (
|
||||
origInference modelInferenceFunc
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
origInference = backend.ModelInferenceFunc
|
||||
appCfg = config.NewApplicationConfig()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
backend.ModelInferenceFunc = origInference
|
||||
})
|
||||
|
||||
// mockBackendUsage installs a stub backend that yields one LLMResponse
|
||||
// carrying the supplied TokenUsage. ComputeChoices' single-attempt path
|
||||
// copies these counts into the value it returns to the worker.
|
||||
mockBackendUsage := func(usage backend.TokenUsage, response string) {
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
return func() (backend.LLMResponse, error) {
|
||||
return backend.LLMResponse{
|
||||
Response: response,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
makeReq := func() *schema.OpenAIRequest {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := &schema.OpenAIRequest{
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
req.Model = "test-model" // promoted from BasicModelRequest
|
||||
return req
|
||||
}
|
||||
|
||||
// drainResponses consumes everything the worker pushes onto the channel
|
||||
// so the worker is never blocked on its send. The channel is unbuffered
|
||||
// (matching production), so the drain goroutine must be running before
|
||||
// the worker is called.
|
||||
drainResponses := func(ch <-chan schema.OpenAIResponse) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for range ch {
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
Describe("processStream (no-tools path)", func() {
|
||||
It("returns the cumulative TokenUsage produced by the backend", func() {
|
||||
mockBackendUsage(backend.TokenUsage{Prompt: 18, Completion: 213}, "Hello there")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
|
||||
actual, err := processStream("prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Prompt).To(Equal(18),
|
||||
"prompt tokens must round-trip from backend through processStream")
|
||||
Expect(actual.Completion).To(Equal(213),
|
||||
"completion tokens must round-trip from backend through processStream")
|
||||
})
|
||||
|
||||
It("returns zero TokenUsage when the backend reports zero (negative control)", func() {
|
||||
mockBackendUsage(backend.TokenUsage{}, "x")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
|
||||
actual, err := processStream("prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Prompt).To(BeZero())
|
||||
Expect(actual.Completion).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("processStreamWithTools (tools path)", func() {
|
||||
It("returns the cumulative TokenUsage produced by the backend", func() {
|
||||
// This is the direct regression check for issue #9927: with tools
|
||||
// enabled, the trailer was reporting {0,0,0} because the worker
|
||||
// discarded ComputeChoices' second return value.
|
||||
mockBackendUsage(backend.TokenUsage{Prompt: 18, Completion: 213}, "answer")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
var textContent string
|
||||
|
||||
actual, err := processStreamWithTools("none", "prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0, &textContent)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Prompt).To(Equal(18),
|
||||
"prompt tokens must round-trip from backend through processStreamWithTools (issue #9927)")
|
||||
Expect(actual.Completion).To(Equal(213),
|
||||
"completion tokens must round-trip from backend through processStreamWithTools (issue #9927)")
|
||||
})
|
||||
|
||||
It("forwards timing fields when the backend supplies them", func() {
|
||||
mockBackendUsage(backend.TokenUsage{
|
||||
Prompt: 10, Completion: 20,
|
||||
TimingPromptProcessing: 0.5,
|
||||
TimingTokenGeneration: 1.5,
|
||||
}, "answer")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
var textContent string
|
||||
|
||||
actual, err := processStreamWithTools("none", "prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0, &textContent)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.TimingPromptProcessing).To(Equal(0.5))
|
||||
Expect(actual.TimingTokenGeneration).To(Equal(1.5))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
390
core/http/endpoints/openai/chat_stream_workers.go
Normal file
390
core/http/endpoints/openai/chat_stream_workers.go
Normal file
@@ -0,0 +1,390 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// processStream is the streaming worker for chat completions with no
|
||||
// tool/function calling involved. It pushes SSE-shaped chunks onto
|
||||
// `responses` and returns the authoritative cumulative TokenUsage from
|
||||
// the prediction so the caller can populate the include_usage trailer
|
||||
// without having to peek inside the chunks.
|
||||
//
|
||||
// The caller owns the `responses` channel and is expected to read from
|
||||
// it while this function runs; processStream closes the channel before
|
||||
// returning.
|
||||
func processStream(
|
||||
s string,
|
||||
req *schema.OpenAIRequest,
|
||||
cfg *config.ModelConfig,
|
||||
cl *config.ModelConfigLoader,
|
||||
startupOptions *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
responses chan schema.OpenAIResponse,
|
||||
id string,
|
||||
created int,
|
||||
) (backend.TokenUsage, error) {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
var template string
|
||||
if cfg.TemplateConfig.UseTokenizerTemplate {
|
||||
template = cfg.GetModelTemplate()
|
||||
} else {
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
_, finalUsage, _, err := ComputeChoices(req, s, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
// Always keep the Go-side extractor in sync with raw tokens so it
|
||||
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
// When C++ autoparser chat deltas are available, prefer them: they
|
||||
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
||||
// Otherwise fall back to Go-side extraction.
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
}
|
||||
if reasoningDelta != "" {
|
||||
delta.Reasoning = &reasoningDelta
|
||||
}
|
||||
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
return finalUsage, err
|
||||
}
|
||||
|
||||
// processStreamWithTools is the streaming worker for chat completions
|
||||
// with tools / function calling. Same contract as processStream: pushes
|
||||
// chunks onto `responses`, closes the channel, returns the cumulative
|
||||
// TokenUsage.
|
||||
//
|
||||
// Returning the TokenUsage as a normal Go value (rather than smuggling
|
||||
// it on a sentinel chunk) is the fix for issue #9927 — the previous
|
||||
// implementation discarded the value from ComputeChoices, so the
|
||||
// include_usage trailer reported zeros whenever `tools` was in play.
|
||||
func processStreamWithTools(
|
||||
noAction string,
|
||||
prompt string,
|
||||
req *schema.OpenAIRequest,
|
||||
cfg *config.ModelConfig,
|
||||
cl *config.ModelConfigLoader,
|
||||
startupOptions *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
responses chan schema.OpenAIResponse,
|
||||
id string,
|
||||
created int,
|
||||
textContentToReturn *string,
|
||||
) (backend.TokenUsage, error) {
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
if cfg.TemplateConfig.UseTokenizerTemplate {
|
||||
template = cfg.GetModelTemplate()
|
||||
} else {
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
sentReasoning := false
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
_, finalUsage, chatDeltas, err := ComputeChoices(req, prompt, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
// Track whether ChatDeltas from the C++ autoparser contain
|
||||
// tool calls or content, so the retry decision can account for them.
|
||||
for _, d := range usage.ChatDeltas {
|
||||
if len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaToolCalls = true
|
||||
}
|
||||
if d.Content != "" {
|
||||
hasChatDeltaContent = true
|
||||
}
|
||||
}
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
if usage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
if reasoningDelta != "" {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentReasoning = true
|
||||
}
|
||||
|
||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||
// have been detected. Once the incremental parser finds tool calls,
|
||||
// content stops: per OpenAI spec, content and tool_calls don't mix.
|
||||
if lastEmittedCount == 0 && contentDelta != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &contentDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
// Try incremental XML parsing for streaming support using iterative parser
|
||||
// This allows emitting partial tool calls as they're being generated
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
|
||||
// Determine XML format from config
|
||||
var xmlFormat *functions.XMLToolCallFormat
|
||||
if cfg.FunctionsConfig.XMLFormat != nil {
|
||||
xmlFormat = cfg.FunctionsConfig.XMLFormat
|
||||
} else if cfg.FunctionsConfig.XMLFormatPreset != "" {
|
||||
xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset)
|
||||
}
|
||||
|
||||
// Use iterative parser for streaming (partial parsing enabled)
|
||||
// Try XML parsing first
|
||||
partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
|
||||
if parseErr == nil && len(partialResults) > 0 {
|
||||
// Emit new XML tool calls that weren't emitted before
|
||||
if len(partialResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(partialResults); i++ {
|
||||
toolCall := partialResults[i]
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: toolCall.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
select {
|
||||
case responses <- initialMessage:
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastEmittedCount = len(partialResults)
|
||||
}
|
||||
} else {
|
||||
// Try JSON tool call parsing for streaming.
|
||||
// Only emit NEW tool calls (same guard as XML parser above).
|
||||
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
||||
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
||||
jsonObj := jsonResults[i]
|
||||
name, ok := jsonObj["name"].(string)
|
||||
if !ok || name == "" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||
if argsStr, ok := argsVal.(string); ok {
|
||||
args = argsStr
|
||||
} else {
|
||||
argsBytes, _ := json.Marshal(argsVal)
|
||||
args = string(argsBytes)
|
||||
}
|
||||
}
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
}
|
||||
lastEmittedCount = len(jsonResults)
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
func(attempt int) bool {
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry).
|
||||
// Also check ChatDelta flags: when the C++ autoparser is active,
|
||||
// tool calls and content are delivered via ChatDeltas while the
|
||||
// raw message is cleared. Without this check, we'd retry
|
||||
// unnecessarily, losing valid results and concatenating output.
|
||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||
hasContent := cleaned != "" || hasChatDeltaContent
|
||||
if !hasContent && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
hasChatDeltaToolCalls = false
|
||||
hasChatDeltaContent = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return finalUsage, err
|
||||
}
|
||||
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
||||
var functionResults []functions.FuncCallResults
|
||||
var reasoning string
|
||||
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
||||
functionResults = deltaToolCalls
|
||||
// Use content/reasoning from deltas too
|
||||
*textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning = extractor.Reasoning()
|
||||
cleanedResult := extractor.CleanedContent()
|
||||
*textContentToReturn = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn)
|
||||
// noAction is a sentinel "just answer" pseudo-function: not a real
|
||||
// tool call. Scan the whole slice rather than only index 0 so we
|
||||
// don't drop a real tool call that happens to follow a noAction
|
||||
// entry, and so the default branch isn't entered with only noAction
|
||||
// entries to emit as tool_calls.
|
||||
noActionToRun := !hasRealCall(functionResults, noAction)
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
// The final usage trailer (when the caller opted in with
|
||||
// stream_options.include_usage) is built by the outer streaming
|
||||
// loop from the TokenUsage this function returns, not from any
|
||||
// chunk on the responses channel.
|
||||
var result string
|
||||
if !sentInitialRole {
|
||||
var hqErr error
|
||||
result, hqErr = handleQuestion(cfg, functionResults, extractor.CleanedContent(), prompt)
|
||||
if hqErr != nil {
|
||||
xlog.Error("error handling question", "error", hqErr)
|
||||
return finalUsage, hqErr
|
||||
}
|
||||
}
|
||||
for _, chunk := range buildNoActionFinalChunks(
|
||||
id, req.Model, created,
|
||||
sentInitialRole, sentReasoning,
|
||||
result, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
|
||||
default:
|
||||
for _, chunk := range buildDeferredToolCallChunks(
|
||||
id, req.Model, created,
|
||||
functionResults, lastEmittedCount,
|
||||
sentInitialRole, *textContentToReturn,
|
||||
sentReasoning, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
}
|
||||
|
||||
close(responses)
|
||||
return finalUsage, err
|
||||
}
|
||||
@@ -17,16 +17,20 @@ import (
|
||||
)
|
||||
|
||||
type APIExchangeRequest struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
BodyTruncated bool `json:"body_truncated,omitempty"`
|
||||
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
|
||||
}
|
||||
|
||||
type APIExchangeResponse struct {
|
||||
Status int `json:"status"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
Status int `json:"status"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
BodyTruncated bool `json:"body_truncated,omitempty"`
|
||||
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
|
||||
}
|
||||
|
||||
type APIExchange struct {
|
||||
@@ -66,11 +70,29 @@ var doInitializeTracing = sync.OnceFunc(func() {
|
||||
|
||||
type bodyWriter struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
body *bytes.Buffer
|
||||
maxBytes int // 0 = unlimited capture
|
||||
truncated bool
|
||||
totalBytes int // bytes the upstream handler wrote, even past the cap
|
||||
}
|
||||
|
||||
func (w *bodyWriter) Write(b []byte) (int, error) {
|
||||
w.body.Write(b)
|
||||
// Capture into the trace buffer up to maxBytes, then drop the overflow
|
||||
// so a chatty endpoint can't grow the buffer without bound. The full
|
||||
// payload still flows through to the real client below.
|
||||
w.totalBytes += len(b)
|
||||
if w.maxBytes <= 0 {
|
||||
w.body.Write(b)
|
||||
} else if remain := w.maxBytes - w.body.Len(); remain > 0 {
|
||||
if remain >= len(b) {
|
||||
w.body.Write(b)
|
||||
} else {
|
||||
w.body.Write(b[:remain])
|
||||
w.truncated = true
|
||||
}
|
||||
} else {
|
||||
w.truncated = true
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
@@ -80,6 +102,20 @@ func (w *bodyWriter) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// truncateForTrace copies at most maxBytes bytes of body into a freshly
// allocated slice and reports whether anything was cut off. A maxBytes of
// zero or less means "no cap": the whole body is copied and the flag is
// false. The result never aliases body's backing array.
func truncateForTrace(body []byte, maxBytes int) ([]byte, bool) {
	keep := len(body)
	capped := maxBytes > 0 && keep > maxBytes
	if capped {
		keep = maxBytes
	}

	snapshot := make([]byte, keep)
	copy(snapshot, body[:keep])
	return snapshot, capped
}
|
||||
|
||||
func initializeTracing(maxItems int) {
|
||||
tracingMaxItems = maxItems
|
||||
doInitializeTracing()
|
||||
@@ -134,11 +170,18 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Cap captured payload size. Without this, /embeddings and
|
||||
// streaming /chat/completions blow the in-memory buffer into the
|
||||
// tens of MB, which then locks the admin Traces UI fetching the
|
||||
// JSON dump faster than the 5s auto-refresh.
|
||||
maxBodyBytes := app.ApplicationConfig().TracingMaxBodyBytes
|
||||
|
||||
// Wrap response writer to capture body
|
||||
resBody := new(bytes.Buffer)
|
||||
mw := &bodyWriter{
|
||||
ResponseWriter: c.Response().Writer,
|
||||
body: resBody,
|
||||
maxBytes: maxBodyBytes,
|
||||
}
|
||||
c.Response().Writer = mw
|
||||
|
||||
@@ -159,8 +202,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
// via any heap-dump-style introspection, and tokens shouldn't
|
||||
// outlive the request that carried them.
|
||||
requestHeaders := redactSensitiveHeaders(c.Request().Header)
|
||||
requestBody := make([]byte, len(body))
|
||||
copy(requestBody, body)
|
||||
requestBody, requestTruncated := truncateForTrace(body, maxBodyBytes)
|
||||
responseHeaders := redactSensitiveHeaders(c.Response().Header())
|
||||
responseBody := make([]byte, resBody.Len())
|
||||
copy(responseBody, resBody.Bytes())
|
||||
@@ -168,15 +210,19 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Request: APIExchangeRequest{
|
||||
Method: c.Request().Method,
|
||||
Path: c.Path(),
|
||||
Headers: &requestHeaders,
|
||||
Body: &requestBody,
|
||||
Method: c.Request().Method,
|
||||
Path: c.Path(),
|
||||
Headers: &requestHeaders,
|
||||
Body: &requestBody,
|
||||
BodyTruncated: requestTruncated,
|
||||
BodyBytes: len(body),
|
||||
},
|
||||
Response: APIExchangeResponse{
|
||||
Status: status,
|
||||
Headers: &responseHeaders,
|
||||
Body: &responseBody,
|
||||
Status: status,
|
||||
Headers: &responseHeaders,
|
||||
Body: &responseBody,
|
||||
BodyTruncated: mw.truncated,
|
||||
BodyBytes: mw.totalBytes,
|
||||
},
|
||||
}
|
||||
if handlerErr != nil {
|
||||
|
||||
116
core/http/middleware/trace_body_cap_test.go
Normal file
116
core/http/middleware/trace_body_cap_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// The trace middleware copies request and response bodies into an in-memory
|
||||
// buffer that backs the admin /api/traces endpoint. With no upper bound a
|
||||
// chatty workload (embeddings, large completions) trivially produces a
|
||||
// multi-MB response that locks the Traces UI in a loading state — fetching
|
||||
// and parsing the payload outruns the 5-second auto-refresh. These specs
|
||||
// pin the capping contract so future refactors keep both the cap and the
|
||||
// passthrough to the real client intact.
|
||||
|
||||
var _ = Describe("bodyWriter capping", func() {
|
||||
It("captures the full body when maxBytes is 0 (unlimited)", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 0}
|
||||
|
||||
payload := []byte(strings.Repeat("x", 4096))
|
||||
n, err := bw.Write(payload)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(len(payload)))
|
||||
Expect(buf.Len()).To(Equal(len(payload)))
|
||||
Expect(downstream.Body.Len()).To(Equal(len(payload)))
|
||||
Expect(bw.truncated).To(BeFalse())
|
||||
})
|
||||
|
||||
It("stops appending to the trace buffer once maxBytes is reached but still forwards to the client", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 100}
|
||||
|
||||
payload := []byte(strings.Repeat("a", 250))
|
||||
n, err := bw.Write(payload)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(len(payload)), "Write must return the full byte count so callers see no short write")
|
||||
Expect(buf.Len()).To(Equal(100), "trace buffer should hold exactly maxBytes")
|
||||
Expect(downstream.Body.Len()).To(Equal(len(payload)), "client must still receive every byte")
|
||||
Expect(bw.truncated).To(BeTrue())
|
||||
})
|
||||
|
||||
It("handles a write that straddles the cap by keeping only the leading slice", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 10}
|
||||
|
||||
_, err := bw.Write([]byte("12345"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bw.truncated).To(BeFalse())
|
||||
|
||||
_, err = bw.Write([]byte("67890ABCDE"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(buf.String()).To(Equal("1234567890"))
|
||||
Expect(downstream.Body.String()).To(Equal("1234567890ABCDE"))
|
||||
Expect(bw.truncated).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores further writes after the cap was already hit", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 4}
|
||||
|
||||
_, _ = bw.Write([]byte("AAAA"))
|
||||
_, _ = bw.Write([]byte("BBBB"))
|
||||
_, _ = bw.Write([]byte("CCCC"))
|
||||
|
||||
Expect(buf.String()).To(Equal("AAAA"))
|
||||
Expect(downstream.Body.String()).To(Equal("AAAABBBBCCCC"))
|
||||
Expect(bw.truncated).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("truncateForTrace", func() {
|
||||
It("returns the input unchanged when below the cap", func() {
|
||||
in := []byte("hello")
|
||||
out, truncated := truncateForTrace(in, 1024)
|
||||
Expect(truncated).To(BeFalse())
|
||||
Expect(out).To(Equal(in))
|
||||
})
|
||||
|
||||
It("truncates when the input exceeds the cap and signals truncation", func() {
|
||||
in := []byte(strings.Repeat("z", 200))
|
||||
out, truncated := truncateForTrace(in, 64)
|
||||
Expect(truncated).To(BeTrue())
|
||||
Expect(out).To(HaveLen(64))
|
||||
Expect(string(out)).To(Equal(strings.Repeat("z", 64)))
|
||||
})
|
||||
|
||||
It("treats maxBytes <= 0 as unlimited (back-compat with current default)", func() {
|
||||
in := []byte(strings.Repeat("q", 10_000))
|
||||
out, truncated := truncateForTrace(in, 0)
|
||||
Expect(truncated).To(BeFalse())
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
})
|
||||
|
||||
It("does not retain the caller's backing array (defensive copy)", func() {
|
||||
in := []byte("abcdefghij")
|
||||
out, truncated := truncateForTrace(in, 4)
|
||||
Expect(truncated).To(BeTrue())
|
||||
Expect(string(out)).To(Equal("abcd"))
|
||||
|
||||
// Mutating the source must not corrupt the trace copy.
|
||||
in[0] = 'Z'
|
||||
Expect(string(out)).To(Equal("abcd"))
|
||||
})
|
||||
})
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -14,18 +15,37 @@ import (
|
||||
|
||||
const (
|
||||
usageFlushInterval = 5 * time.Second
|
||||
usageMaxPending = 5000
|
||||
// usageMaxPending bounds the in-memory queue. Sized for bursty inference
|
||||
// traffic on a self-hosted instance with a slow or unavailable DB.
|
||||
usageMaxPending = 50000
|
||||
)
|
||||
|
||||
// usageBatcher accumulates usage records and flushes them to the DB periodically.
|
||||
type usageBatcher struct {
|
||||
mu sync.Mutex
|
||||
pending []*auth.UsageRecord
|
||||
db *gorm.DB
|
||||
mu sync.Mutex
|
||||
pending []*auth.UsageRecord
|
||||
db *gorm.DB
|
||||
stop chan struct{}
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// droppedRecords counts records discarded because the in-memory queue was full.
|
||||
// Used to rate-limit the warn log so a sustained outage doesn't flood it.
|
||||
var droppedRecords atomic.Uint64
|
||||
|
||||
func (b *usageBatcher) add(r *auth.UsageRecord) {
|
||||
b.mu.Lock()
|
||||
if len(b.pending) >= usageMaxPending {
|
||||
b.mu.Unlock()
|
||||
// Rate-limit: one warn per 1024 drops keeps the log readable.
|
||||
n := droppedRecords.Add(1)
|
||||
if n&1023 == 1 {
|
||||
xlog.Warn("usage batcher full, dropping record",
|
||||
"cap", usageMaxPending, "total_dropped", n)
|
||||
}
|
||||
return
|
||||
}
|
||||
b.pending = append(b.pending, r)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
@@ -42,31 +62,102 @@ func (b *usageBatcher) flush() {
|
||||
|
||||
if err := b.db.Create(&batch).Error; err != nil {
|
||||
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
|
||||
// Re-queue failed records with a cap to avoid unbounded growth
|
||||
// Cap-aware re-queue: prepend as much of the failed batch as fits
|
||||
// alongside any records added concurrently with the failed write.
|
||||
b.mu.Lock()
|
||||
if len(b.pending) < usageMaxPending {
|
||||
b.pending = append(batch, b.pending...)
|
||||
room := usageMaxPending - len(b.pending)
|
||||
if room > 0 {
|
||||
if room > len(batch) {
|
||||
room = len(batch)
|
||||
}
|
||||
b.pending = append(batch[:room], b.pending...)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
var batcher *usageBatcher
|
||||
func (b *usageBatcher) run() {
|
||||
defer close(b.done)
|
||||
ticker := time.NewTicker(usageFlushInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b.flush()
|
||||
case <-b.stop:
|
||||
b.flush() // final drain
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *usageBatcher) shutdown() {
|
||||
b.stopOnce.Do(func() {
|
||||
close(b.stop)
|
||||
<-b.done
|
||||
})
|
||||
}
|
||||
|
||||
// The package-level batcher is guarded by batcherMu so Init / Shutdown cycles
|
||||
// (the test pattern) don't race against UsageMiddleware reads.
|
||||
var (
|
||||
batcherMu sync.RWMutex
|
||||
batcher *usageBatcher
|
||||
)
|
||||
|
||||
func currentBatcher() *usageBatcher {
|
||||
batcherMu.RLock()
|
||||
defer batcherMu.RUnlock()
|
||||
return batcher
|
||||
}
|
||||
|
||||
// InitUsageRecorder starts a background goroutine that periodically flushes
|
||||
// accumulated usage records to the database.
|
||||
// accumulated usage records to the database. Calling it more than once
|
||||
// shuts down the previous batcher first so its goroutine doesn't leak.
|
||||
func InitUsageRecorder(db *gorm.DB) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
batcher = &usageBatcher{db: db}
|
||||
go func() {
|
||||
ticker := time.NewTicker(usageFlushInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
batcher.flush()
|
||||
}
|
||||
}()
|
||||
|
||||
batcherMu.Lock()
|
||||
old := batcher
|
||||
batcher = nil
|
||||
batcherMu.Unlock()
|
||||
if old != nil {
|
||||
old.shutdown()
|
||||
}
|
||||
|
||||
b := &usageBatcher{
|
||||
db: db,
|
||||
stop: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
batcherMu.Lock()
|
||||
batcher = b
|
||||
batcherMu.Unlock()
|
||||
|
||||
go b.run()
|
||||
}
|
||||
|
||||
// ShutdownUsageRecorder stops the background flusher and synchronously drains
|
||||
// pending records once. Safe to call multiple times. Not yet wired into the
|
||||
// application lifecycle; intended for graceful process exit and tests.
|
||||
func ShutdownUsageRecorder() {
|
||||
batcherMu.Lock()
|
||||
b := batcher
|
||||
batcher = nil
|
||||
batcherMu.Unlock()
|
||||
if b != nil {
|
||||
b.shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// FlushNow synchronously flushes any pending usage records. Intended for tests
|
||||
// that need deterministic behaviour without waiting for the ticker.
|
||||
func FlushNow() {
|
||||
if b := currentBatcher(); b != nil {
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// usageResponseBody is the minimal structure we need from the response JSON.
|
||||
@@ -84,7 +175,8 @@ type usageResponseBody struct {
|
||||
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if db == nil || batcher == nil {
|
||||
b := currentBatcher()
|
||||
if db == nil || b == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
@@ -149,9 +241,17 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
source := auth.GetSource(c)
|
||||
if source == "" {
|
||||
// Auth disabled or unrecognised path: classify as web so the row is still
|
||||
// bucketable rather than silently dropped from per-source aggregates.
|
||||
source = auth.UsageSourceWeb
|
||||
}
|
||||
|
||||
record := &auth.UsageRecord{
|
||||
UserID: user.ID,
|
||||
UserName: user.Name,
|
||||
Source: source,
|
||||
Model: resp.Model,
|
||||
Endpoint: c.Request().URL.Path,
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
@@ -161,7 +261,13 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
CreatedAt: startTime,
|
||||
}
|
||||
|
||||
batcher.add(record)
|
||||
if key := auth.GetAPIKey(c); key != nil {
|
||||
id := key.ID
|
||||
record.APIKeyID = &id
|
||||
record.APIKeyName = key.Name
|
||||
}
|
||||
|
||||
b.add(record)
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
140
core/http/middleware/usage_test.go
Normal file
140
core/http/middleware/usage_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
//go:build auth
|
||||
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// testAuthDB returns a fresh in-memory SQLite auth DB.
|
||||
func testAuthDB() *gorm.DB {
|
||||
db, err := auth.InitDB(":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
var _ = Describe("UsageMiddleware", func() {
|
||||
var (
|
||||
e *echo.Echo
|
||||
db *gorm.DB
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
db = testAuthDB()
|
||||
e = echo.New()
|
||||
middleware.InitUsageRecorder(db)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
middleware.ShutdownUsageRecorder()
|
||||
})
|
||||
|
||||
okHandler := func(c echo.Context) error {
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "gpt-4",
|
||||
"usage": map[string]int{
|
||||
"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15,
|
||||
},
|
||||
})
|
||||
c.Response().Header().Set("Content-Type", "application/json")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
_, _ = c.Response().Write(body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushNow drains pending records synchronously, replacing the 6s sleep
|
||||
// that was previously needed to wait for the batcher's ticker.
|
||||
flush := middleware.FlushNow
|
||||
|
||||
It("records source=web when auth_source is web", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||
c.Set("auth_source", auth.UsageSourceWeb)
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
flush()
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(rec.APIKeyID).To(BeNil())
|
||||
Expect(rec.APIKeyName).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("records source=apikey with snapshotted name when auth_apikey is set", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||
c.Set("auth_source", auth.UsageSourceAPIKey)
|
||||
c.Set("auth_apikey", &auth.UserAPIKey{ID: "key-1", Name: "ci-runner"})
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
flush()
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(rec.APIKeyID).ToNot(BeNil())
|
||||
Expect(*rec.APIKeyID).To(Equal("key-1"))
|
||||
Expect(rec.APIKeyName).To(Equal("ci-runner"))
|
||||
})
|
||||
|
||||
It("FlushNow drains pending records synchronously", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "carol", Name: "Carol"})
|
||||
c.Set("auth_source", auth.UsageSourceWeb)
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
// No sleep: FlushNow should drain immediately.
|
||||
middleware.FlushNow()
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "carol").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||
})
|
||||
|
||||
It("falls back to source=web when auth_source is empty", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||
// no auth_source set
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
flush()
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||
})
|
||||
})
|
||||
@@ -53,7 +53,30 @@
|
||||
},
|
||||
"usage": {
|
||||
"title": "Usage",
|
||||
"subtitle": "API token usage statistics"
|
||||
"subtitle": "API token usage statistics",
|
||||
"sources": {
|
||||
"tab": "Sources",
|
||||
"mixTitle": "Source mix",
|
||||
"ribbonAria": "{{apikey}}% API keys, {{web}}% Web UI, {{legacy}}% Legacy",
|
||||
"topSources": "Top sources over time",
|
||||
"searchPlaceholder": "Search by name or prefix",
|
||||
"sortBy": "Sort",
|
||||
"sortTokens": "Tokens",
|
||||
"sortRequests": "Requests",
|
||||
"sortLastUsed": "Last used",
|
||||
"sortName": "Name",
|
||||
"sortUser": "User",
|
||||
"webUI": "Web UI",
|
||||
"legacy": "Legacy",
|
||||
"revoked": "revoked",
|
||||
"filteredTo": "Filtered to: {{name}}",
|
||||
"clearFilter": "Clear filter",
|
||||
"other": "Other ({{count}})",
|
||||
"noTrafficShort": "No requests in this period.",
|
||||
"noKeysYet": "Once requests come in, you'll see them broken down here.",
|
||||
"createKey": "Create your first API key",
|
||||
"truncatedWarning": "Showing top 200 keys. Apply a filter to narrow further."
|
||||
}
|
||||
},
|
||||
"explorer": {
|
||||
"title": "Explorer",
|
||||
|
||||
@@ -649,6 +649,7 @@
|
||||
align-items: center;
|
||||
gap: var(--spacing-md);
|
||||
padding: var(--spacing-xs) 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.operation-info {
|
||||
@@ -739,6 +740,110 @@
|
||||
color: var(--color-error);
|
||||
}
|
||||
|
||||
/* Operations bar: per-node breakdown (multi-worker installs) */
|
||||
.operation-expand {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-text-muted);
|
||||
cursor: pointer;
|
||||
padding: 0 var(--spacing-xs);
|
||||
font-size: var(--text-xs);
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
.operation-expand:hover {
|
||||
color: var(--color-text-primary);
|
||||
}
|
||||
.operation-expand-label {
|
||||
font-size: var(--text-xs);
|
||||
}
|
||||
|
||||
.operation-nodes-list {
|
||||
list-style: none;
|
||||
margin: var(--spacing-xs) 0 0;
|
||||
padding: var(--spacing-xs) 0 0;
|
||||
border-top: 1px solid var(--color-border-subtle);
|
||||
flex-basis: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
.operation-node {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--spacing-sm);
|
||||
padding: var(--spacing-xs) 0;
|
||||
font-size: var(--text-xs);
|
||||
color: var(--color-text-muted);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.operation-node-status {
|
||||
padding: 2px 6px;
|
||||
border-radius: var(--radius-md);
|
||||
font-size: 0.65rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.025em;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.operation-node-status-success {
|
||||
background: var(--color-success-light);
|
||||
color: var(--color-success);
|
||||
}
|
||||
.operation-node-status-error {
|
||||
background: var(--color-error-light);
|
||||
color: var(--color-error);
|
||||
}
|
||||
.operation-node-status-queued {
|
||||
background: var(--color-bg-tertiary);
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
.operation-node-status-running_on_worker {
|
||||
background: var(--color-warning-light);
|
||||
color: var(--color-warning);
|
||||
}
|
||||
.operation-node-status-downloading {
|
||||
background: var(--color-primary-light);
|
||||
color: var(--color-primary);
|
||||
}
|
||||
.operation-node-name {
|
||||
font-weight: 500;
|
||||
color: var(--color-text-secondary);
|
||||
}
|
||||
.operation-node-file {
|
||||
font-family: var(--font-mono);
|
||||
color: var(--color-text-tertiary);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 30ch;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.operation-node-bytes {
|
||||
font-variant-numeric: tabular-nums;
|
||||
color: var(--color-text-tertiary);
|
||||
}
|
||||
.operation-node-pct {
|
||||
font-variant-numeric: tabular-nums;
|
||||
color: var(--color-primary);
|
||||
font-weight: 500;
|
||||
}
|
||||
.operation-node-error {
|
||||
color: var(--color-error);
|
||||
}
|
||||
.operation-node-bar-container {
|
||||
flex-basis: 100%;
|
||||
height: 3px;
|
||||
background: var(--color-surface-sunken);
|
||||
border-radius: var(--radius-full);
|
||||
overflow: hidden;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
.operation-node-bar {
|
||||
height: 100%;
|
||||
background: var(--color-primary);
|
||||
border-radius: var(--radius-full);
|
||||
transition: width var(--duration-slow, 0.3s) var(--ease-spring, ease);
|
||||
}
|
||||
|
||||
/* Toast */
|
||||
.toast-container {
|
||||
position: fixed;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useState, useMemo, useEffect, useRef } from 'react'
|
||||
import Modal from './Modal'
|
||||
import SearchableSelect from './SearchableSelect'
|
||||
import { nodesApi } from '../utils/api'
|
||||
import { nodesApi, backendsApi } from '../utils/api'
|
||||
|
||||
// NodeInstallPicker is the single multi-node install surface used both from
|
||||
// the Backends gallery split-button and from the "Install on more nodes" `+`
|
||||
@@ -240,6 +240,37 @@ export default function NodeInstallPicker({
|
||||
}
|
||||
const clearSelection = () => setSelected(new Set())
|
||||
|
||||
// pollJob keeps asking the backend for a job's status until the job reaches a
// terminal state, then resolves with { done: true, error?: string }. It never
// rejects: a failed status request is reported through `error`. A hard
// wall-clock cap turns a stuck worker into a visible "Failed" instead of an
// endless spinner.
const pollJob = async (jobID) => {
  const POLL_INTERVAL_MS = 1500
  const HARD_CAP_MS = 6 * 60 * 1000 // 6 min - generous for a fresh worker download
  const startedAt = Date.now()
  const pause = () => new Promise((wake) => setTimeout(wake, POLL_INTERVAL_MS))

  while (true) {
    let status
    try {
      status = await backendsApi.getJob(jobID)
    } catch (err) {
      return { done: true, error: err?.message || 'polling failed' }
    }

    if (status?.completed) return { done: true }
    if (status?.error) return { done: true, error: status.error }
    // Processed without completing and without an error message: report a
    // generic failure.
    if (status?.processed) return { done: true, error: 'install did not complete' }

    if (Date.now() - startedAt > HARD_CAP_MS) {
      return { done: true, error: 'timed out waiting for install to finish' }
    }
    await pause()
  }
}
|
||||
|
||||
const submit = async () => {
|
||||
if (selected.size === 0 || submitting) return
|
||||
if (counts.overrides > 0 && !showMismatchConfirm) {
|
||||
@@ -255,38 +286,68 @@ export default function NodeInstallPicker({
|
||||
return next
|
||||
})
|
||||
|
||||
const results = await Promise.allSettled(ids.map(id =>
|
||||
// Phase 1: dispatch all installs in parallel. Each POST returns immediately
|
||||
// with { jobID } now that the handler is async.
|
||||
const dispatchResults = await Promise.allSettled(ids.map(id =>
|
||||
nodesApi.installBackend(id, effectiveBackendName)
|
||||
.then(r => ({ id, ok: true, message: r?.message }))
|
||||
.catch(err => ({ id, ok: false, error: err?.message || 'install failed' }))
|
||||
.then(r => ({ id, ok: true, jobID: r?.jobID }))
|
||||
.catch(err => ({ id, ok: false, error: err?.message || 'dispatch failed' }))
|
||||
))
|
||||
|
||||
let successCount = 0, failCount = 0
|
||||
setPerNode(prev => {
|
||||
const next = { ...prev }
|
||||
for (const r of results) {
|
||||
if (r.status !== 'fulfilled') continue
|
||||
const v = r.value
|
||||
if (v.ok) {
|
||||
next[v.id] = { status: 'done' }
|
||||
successCount++
|
||||
} else {
|
||||
next[v.id] = { status: 'error', error: v.error }
|
||||
failCount++
|
||||
}
|
||||
// Classify dispatch results synchronously OUTSIDE the setter. React may
|
||||
// invoke a functional state updater more than once (StrictMode dev double
|
||||
// invoke, concurrent rendering replay): building the jobs array inside
|
||||
// the closure would duplicate entries and re-poll the same job.
|
||||
const jobs = []
|
||||
const dispatchPatch = {}
|
||||
for (const r of dispatchResults) {
|
||||
if (r.status !== 'fulfilled') continue
|
||||
const v = r.value
|
||||
if (v.ok && v.jobID) {
|
||||
dispatchPatch[v.id] = { status: 'installing', jobID: v.jobID }
|
||||
jobs.push({ nodeID: v.id, jobID: v.jobID })
|
||||
} else {
|
||||
dispatchPatch[v.id] = { status: 'error', error: v.error || 'dispatch failed' }
|
||||
}
|
||||
return next
|
||||
}
|
||||
setPerNode(prev => ({ ...prev, ...dispatchPatch }))
|
||||
|
||||
// Phase 2: poll each job. Promise.all resolves when the last job settles;
|
||||
// intermediate updates flip per-row state via the setPerNode inside pollJob.
|
||||
await Promise.all(jobs.map(async ({ nodeID, jobID }) => {
|
||||
const result = await pollJob(jobID)
|
||||
setPerNode(prev => {
|
||||
const next = { ...prev }
|
||||
if (result.error) {
|
||||
next[nodeID] = { status: 'error', error: result.error, jobID }
|
||||
} else {
|
||||
next[nodeID] = { status: 'done', jobID }
|
||||
}
|
||||
return next
|
||||
})
|
||||
}))
|
||||
|
||||
// Phase 3: summary toast + onComplete. Read latest state via functional setter.
|
||||
let successCount = 0
|
||||
let failCount = 0
|
||||
setPerNode(prev => {
|
||||
for (const v of Object.values(prev)) {
|
||||
if (v.status === 'done') successCount++
|
||||
else if (v.status === 'error') failCount++
|
||||
}
|
||||
return prev
|
||||
})
|
||||
|
||||
setSubmitting(false)
|
||||
|
||||
if (successCount > 0 && onComplete) onComplete()
|
||||
|
||||
if (failCount === 0) {
|
||||
if (failCount === 0 && successCount > 0) {
|
||||
addToast?.(`Installed on ${successCount} node${successCount === 1 ? '' : 's'}`, 'success')
|
||||
setTimeout(() => onClose?.(), 800)
|
||||
} else if (successCount === 0) {
|
||||
} else if (successCount === 0 && failCount > 0) {
|
||||
addToast?.(`Install failed on all ${failCount} node${failCount === 1 ? '' : 's'}`, 'error')
|
||||
} else {
|
||||
} else if (successCount > 0 && failCount > 0) {
|
||||
addToast?.(`Installed on ${successCount}, failed on ${failCount}`, 'warning')
|
||||
}
|
||||
}
|
||||
@@ -297,32 +358,58 @@ export default function NodeInstallPicker({
|
||||
.map(([id]) => id)
|
||||
if (failedIds.length === 0) return
|
||||
setSelected(new Set(failedIds))
|
||||
// Replace state for failed rows so they show "installing" again, not stale errors.
|
||||
setPerNode(prev => {
|
||||
const next = { ...prev }
|
||||
failedIds.forEach(id => { next[id] = { status: 'installing' } })
|
||||
return next
|
||||
})
|
||||
setSubmitting(true)
|
||||
const results = await Promise.allSettled(failedIds.map(id =>
|
||||
|
||||
const dispatchResults = await Promise.allSettled(failedIds.map(id =>
|
||||
nodesApi.installBackend(id, effectiveBackendName)
|
||||
.then(r => ({ id, ok: true, message: r?.message }))
|
||||
.catch(err => ({ id, ok: false, error: err?.message || 'install failed' }))
|
||||
.then(r => ({ id, ok: true, jobID: r?.jobID }))
|
||||
.catch(err => ({ id, ok: false, error: err?.message || 'dispatch failed' }))
|
||||
))
|
||||
|
||||
// Same precaution as in submit(): classify outside the functional setter
|
||||
// so a replayed updater can't push duplicate jobs into the polling list.
|
||||
const jobs = []
|
||||
const dispatchPatch = {}
|
||||
for (const r of dispatchResults) {
|
||||
if (r.status !== 'fulfilled') continue
|
||||
const v = r.value
|
||||
if (v.ok && v.jobID) {
|
||||
dispatchPatch[v.id] = { status: 'installing', jobID: v.jobID }
|
||||
jobs.push({ nodeID: v.id, jobID: v.jobID })
|
||||
} else {
|
||||
dispatchPatch[v.id] = { status: 'error', error: v.error || 'dispatch failed' }
|
||||
}
|
||||
}
|
||||
setPerNode(prev => ({ ...prev, ...dispatchPatch }))
|
||||
|
||||
await Promise.all(jobs.map(async ({ nodeID, jobID }) => {
|
||||
const result = await pollJob(jobID)
|
||||
setPerNode(prev => {
|
||||
const next = { ...prev }
|
||||
if (result.error) next[nodeID] = { status: 'error', error: result.error, jobID }
|
||||
else next[nodeID] = { status: 'done', jobID }
|
||||
return next
|
||||
})
|
||||
}))
|
||||
|
||||
setSubmitting(false)
|
||||
|
||||
let successCount = 0, failCount = 0
|
||||
setPerNode(prev => {
|
||||
const next = { ...prev }
|
||||
for (const r of results) {
|
||||
if (r.status !== 'fulfilled') continue
|
||||
const v = r.value
|
||||
if (v.ok) { next[v.id] = { status: 'done' }; successCount++ }
|
||||
else { next[v.id] = { status: 'error', error: v.error }; failCount++ }
|
||||
for (const id of failedIds) {
|
||||
const v = prev[id]
|
||||
if (v?.status === 'done') successCount++
|
||||
else if (v?.status === 'error') failCount++
|
||||
}
|
||||
return next
|
||||
return prev
|
||||
})
|
||||
setSubmitting(false)
|
||||
if (successCount > 0 && onComplete) onComplete()
|
||||
if (failCount === 0) {
|
||||
if (failCount === 0 && successCount > 0) {
|
||||
addToast?.(`Installed on ${successCount} node${successCount === 1 ? '' : 's'}`, 'success')
|
||||
setTimeout(() => onClose?.(), 800)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,33 @@
|
||||
import { useState } from 'react'
|
||||
import { useOperations } from '../hooks/useOperations'
|
||||
|
||||
// Display label for each per-node status value shown in the expanded node
// list. A status with no entry here is rendered as its raw value.
const nodeStatusLabels = {
|
||||
success: 'Done',
|
||||
error: 'Failed',
|
||||
queued: 'Queued',
|
||||
running_on_worker: 'Worker busy',
|
||||
downloading: 'Downloading',
|
||||
}
|
||||
|
||||
// Hover text attached to a node's status badge when its status is
// 'running_on_worker'.
const runningOnWorkerTooltip = 'NATS round-trip timed out, but the worker is still installing in the background. The reconciler will confirm completion.'
|
||||
|
||||
export default function OperationsBar() {
|
||||
const { operations, cancelOperation, dismissFailedOp } = useOperations()
|
||||
const [expanded, setExpanded] = useState({})
|
||||
|
||||
if (operations.length === 0) return null
|
||||
|
||||
const toggle = (key) => setExpanded((m) => ({ ...m, [key]: !m[key] }))
|
||||
|
||||
return (
|
||||
<div className="operations-bar">
|
||||
{operations.map(op => (
|
||||
<div key={op.jobID || op.id} className="operation-item">
|
||||
{operations.map(op => {
|
||||
const key = op.jobID || op.id
|
||||
const nodes = Array.isArray(op.nodes) ? op.nodes : []
|
||||
const canExpand = nodes.length > 1
|
||||
const isOpen = !!expanded[key]
|
||||
return (
|
||||
<div key={key} className="operation-item">
|
||||
<div className="operation-info">
|
||||
{op.error ? (
|
||||
<i className="fas fa-circle-exclamation" style={{ color: 'var(--color-error)', marginRight: 'var(--spacing-xs)' }} />
|
||||
@@ -80,8 +99,55 @@ export default function OperationsBar() {
|
||||
<i className="fas fa-xmark" />
|
||||
</button>
|
||||
) : null}
|
||||
{canExpand && (
|
||||
<button
|
||||
type="button"
|
||||
className="operation-expand"
|
||||
onClick={() => toggle(key)}
|
||||
aria-expanded={isOpen}
|
||||
title={isOpen ? 'Hide per-node detail' : `Show ${nodes.length} nodes`}
|
||||
>
|
||||
<i className={`fas fa-chevron-${isOpen ? 'up' : 'down'}`} />
|
||||
<span className="operation-expand-label">{nodes.length} nodes</span>
|
||||
</button>
|
||||
)}
|
||||
{canExpand && isOpen && (
|
||||
<ul className="operation-nodes-list">
|
||||
{nodes.map((n) => (
|
||||
<li key={n.node_id} className={`operation-node operation-node-${n.status}`}>
|
||||
<span
|
||||
className={`operation-node-status operation-node-status-${n.status}`}
|
||||
title={n.status === 'running_on_worker' ? runningOnWorkerTooltip : undefined}
|
||||
>
|
||||
{nodeStatusLabels[n.status] || n.status}
|
||||
</span>
|
||||
<span className="operation-node-name">{n.node_name || n.node_id}</span>
|
||||
{n.file_name && <span className="operation-node-file">{n.file_name}</span>}
|
||||
{(n.current || n.total) && (
|
||||
<span className="operation-node-bytes">
|
||||
{n.current || '?'} / {n.total || '?'}
|
||||
</span>
|
||||
)}
|
||||
{n.percentage > 0 && (
|
||||
<span className="operation-node-pct">{Math.round(n.percentage)}%</span>
|
||||
)}
|
||||
{n.error && (
|
||||
<span className="operation-node-error" title={n.error}>
|
||||
{n.error.length > 80 ? n.error.slice(0, 80) + '...' : n.error}
|
||||
</span>
|
||||
)}
|
||||
{n.percentage > 0 && n.percentage < 100 && (
|
||||
<div className="operation-node-bar-container">
|
||||
<div className="operation-node-bar" style={{ width: `${n.percentage}%` }} />
|
||||
</div>
|
||||
)}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -179,16 +179,19 @@ export default function Backends() {
|
||||
|
||||
// Install a single gallery backend on a specific node, used in target-node
|
||||
// mode (the URL has ?target=<node-id> set from the Nodes page entry point).
|
||||
// The handler is async - we dispatch and let the global Operations panel
|
||||
// surface progress; no need to await completion here.
|
||||
const handleInstallOnTarget = async (id) => {
|
||||
if (!targetNode) return
|
||||
try {
|
||||
await nodesApi.installBackend(targetNode.id, id)
|
||||
addToast(`Installing ${id} on ${targetNode.name}…`, 'info')
|
||||
// Per-node install is request-reply, not part of the global jobs feed —
|
||||
// refetch to reflect the new Nodes column state.
|
||||
setTimeout(() => { fetchBackends(); refetchNodes() }, 600)
|
||||
addToast(`Installing ${id} on ${targetNode.name}...`, 'info')
|
||||
// The install runs async via the gallery job queue. Refetch shortly so
|
||||
// the Nodes column reflects "installing" state; the Operations panel
|
||||
// tracks the actual progress until completion.
|
||||
setTimeout(() => { fetchBackends(); refetchNodes() }, 1200)
|
||||
} catch (err) {
|
||||
addToast(`Install failed on ${targetNode.name}: ${err.message}`, 'error')
|
||||
addToast(`Install dispatch failed on ${targetNode.name}: ${err.message}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -406,7 +406,15 @@ export default function Traces() {
|
||||
<button className="btn btn-secondary btn-sm" onClick={fetchTraces}><i className="fas fa-rotate" /> Refresh</button>
|
||||
<button className="btn btn-secondary btn-sm" onClick={handleExport} disabled={traces.length === 0}><i className="fas fa-download" /> Export</button>
|
||||
<div style={{ flex: 1 }} />
|
||||
<button className="btn btn-danger btn-sm" onClick={handleClear} disabled={traces.length === 0}><i className="fas fa-trash" /> Clear</button>
|
||||
<button
|
||||
className="btn btn-danger btn-sm"
|
||||
onClick={handleClear}
|
||||
/* Stay enabled while loading: a massive in-memory trace buffer is
|
||||
precisely the case where the user can't see the table yet and
|
||||
needs Clear to recover. Clearing an already-empty server-side
|
||||
buffer is a harmless no-op. */
|
||||
disabled={!loading && traces.length === 0}
|
||||
><i className="fas fa-trash" /> Clear</button>
|
||||
</div>
|
||||
|
||||
{settings && (() => {
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { useAuth } from '../context/AuthContext'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import SourcesTab from './Usage/SourcesTab'
|
||||
|
||||
const PERIODS = [
|
||||
{ key: 'day', label: 'Day' },
|
||||
@@ -724,23 +725,27 @@ export default function Usage() {
|
||||
{p.label}
|
||||
</button>
|
||||
))}
|
||||
<div style={{ width: 1, height: 20, background: 'var(--color-border-subtle)', margin: '0 var(--spacing-xs)' }} />
|
||||
<button
|
||||
className={`btn btn-sm ${activeTab === 'models' ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => setActiveTab('models')}
|
||||
>
|
||||
<i className="fas fa-cube" style={{ fontSize: '0.7rem' }} /> Models
|
||||
</button>
|
||||
{isAdmin && (
|
||||
<>
|
||||
<div style={{ width: 1, height: 20, background: 'var(--color-border-subtle)', margin: '0 var(--spacing-xs)' }} />
|
||||
<button
|
||||
className={`btn btn-sm ${activeTab === 'models' ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => setActiveTab('models')}
|
||||
>
|
||||
<i className="fas fa-cube" style={{ fontSize: '0.7rem' }} /> Models
|
||||
</button>
|
||||
<button
|
||||
className={`btn btn-sm ${activeTab === 'users' ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => setActiveTab('users')}
|
||||
>
|
||||
<i className="fas fa-users" style={{ fontSize: '0.7rem' }} /> Users
|
||||
</button>
|
||||
</>
|
||||
<button
|
||||
className={`btn btn-sm ${activeTab === 'users' ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => setActiveTab('users')}
|
||||
>
|
||||
<i className="fas fa-users" style={{ fontSize: '0.7rem' }} /> Users
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
className={`btn btn-sm ${activeTab === 'sources' ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => setActiveTab('sources')}
|
||||
>
|
||||
<i className="fas fa-key" style={{ fontSize: '0.7rem' }} /> {t('usage.sources.tab')}
|
||||
</button>
|
||||
<div style={{ flex: 1 }} />
|
||||
<button className="btn btn-secondary btn-sm" onClick={fetchUsage} disabled={loading} style={{ gap: 4 }}>
|
||||
<i className={`fas fa-rotate${loading ? ' fa-spin' : ''}`} /> Refresh
|
||||
@@ -884,6 +889,10 @@ export default function Usage() {
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
|
||||
{activeTab === 'sources' && (
|
||||
<SourcesTab period={period} adminUserId={selectedUserId} />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
83
core/http/react-ui/src/pages/Usage/SourceMixRibbon.jsx
Normal file
83
core/http/react-ui/src/pages/Usage/SourceMixRibbon.jsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
// Fill color per source class, used for both the ribbon segments and the
// legend swatches. The hex values are CSS fallbacks for when the custom
// property is not defined.
const SEGMENT_COLORS = {
|
||||
apikey: 'var(--color-primary)',
|
||||
web: 'var(--color-info, #3b82f6)',
|
||||
legacy: 'var(--color-warning, #f59e0b)',
|
||||
}
|
||||
|
||||
// SourceMixRibbon renders one segmented horizontal bar showing the share of
|
||||
// tokens by source class (apikey / web / legacy). Clicking a segment invokes
|
||||
// onSelectSourceClass with the segment key so the parent can filter the view.
|
||||
//
|
||||
// Props:
|
||||
// bySource: { apikey?: {tokens, requests}, web?: {...}, legacy?: {...} }
|
||||
// keyCount: number of distinct API keys in the dataset (for the legend)
|
||||
// onSelectSourceClass: (cls: 'apikey'|'web'|'legacy') => void (optional)
|
||||
export default function SourceMixRibbon({ bySource = {}, keyCount = 0, onSelectSourceClass }) {
|
||||
const { t } = useTranslation('admin')
|
||||
|
||||
const apikey = (bySource.apikey?.tokens) || 0
|
||||
const web = (bySource.web?.tokens) || 0
|
||||
const legacy = (bySource.legacy?.tokens) || 0
|
||||
const total = apikey + web + legacy || 1
|
||||
|
||||
const pct = (n) => Math.round((n / total) * 100)
|
||||
const apiPct = pct(apikey)
|
||||
const webPct = pct(web)
|
||||
const legacyPct = pct(legacy)
|
||||
|
||||
const segments = [
|
||||
{ key: 'apikey', label: `${apiPct}% API keys (${keyCount})`, pct: apiPct, color: SEGMENT_COLORS.apikey },
|
||||
{ key: 'web', label: `${webPct}% ${t('usage.sources.webUI')}`, pct: webPct, color: SEGMENT_COLORS.web },
|
||||
{ key: 'legacy', label: `${legacyPct}% ${t('usage.sources.legacy')}`, pct: legacyPct, color: SEGMENT_COLORS.legacy },
|
||||
].filter((s) => s.pct > 0)
|
||||
|
||||
return (
|
||||
<div
|
||||
role="group"
|
||||
aria-label={t('usage.sources.ribbonAria', { apikey: apiPct, web: webPct, legacy: legacyPct })}
|
||||
style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-xs)' }}
|
||||
>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 600, color: 'var(--color-text-primary)' }}>
|
||||
{t('usage.sources.mixTitle')}
|
||||
</div>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
height: 12,
|
||||
borderRadius: 'var(--radius-sm)',
|
||||
overflow: 'hidden',
|
||||
border: '1px solid var(--color-border-subtle)',
|
||||
}}
|
||||
>
|
||||
{segments.map((s) => (
|
||||
<button
|
||||
key={s.key}
|
||||
type="button"
|
||||
onClick={() => onSelectSourceClass?.(s.key)}
|
||||
aria-label={s.label}
|
||||
style={{
|
||||
width: `${s.pct}%`,
|
||||
background: s.color,
|
||||
border: 'none',
|
||||
padding: 0,
|
||||
cursor: onSelectSourceClass ? 'pointer' : 'default',
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 'var(--spacing-sm)', fontSize: '0.75rem' }}>
|
||||
{segments.map((s) => (
|
||||
<span key={s.key} style={{ display: 'inline-flex', alignItems: 'center', gap: 6 }}>
|
||||
<span
|
||||
style={{ width: 10, height: 10, borderRadius: 2, background: s.color, display: 'inline-block' }}
|
||||
aria-hidden
|
||||
/>
|
||||
{s.label}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
147
core/http/react-ui/src/pages/Usage/SourceTimeChart.jsx
Normal file
147
core/http/react-ui/src/pages/Usage/SourceTimeChart.jsx
Normal file
@@ -0,0 +1,147 @@
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
// Number of highest-token identities that get their own series in the chart;
// every other identity is folded into a single "Other" series.
const TOP_N = 7
|
||||
// Distinct, accessible-ish series colors that read on both light and dark themes.
// Indexed by an identity's rank among the top identities (modulo the length).
|
||||
const SERIES_COLORS = [
|
||||
'var(--color-primary)',
|
||||
'var(--color-success, #10b981)',
|
||||
'var(--color-warning, #f59e0b)',
|
||||
'var(--color-info, #3b82f6)',
|
||||
'var(--color-danger, #ef4444)',
|
||||
'#a855f7',
|
||||
'#ec4899',
|
||||
]
|
||||
// Color of the aggregated "Other" series.
const OTHER_COLOR = 'var(--color-text-muted, #94a3b8)'
|
||||
|
||||
// Returns the key a usage bucket is grouped under in the chart: its API key
// id when it has one, otherwise its source name, otherwise 'unknown'.
function identityFor(bucket) {
  const { api_key_id: keyId, source } = bucket
  if (keyId) return keyId
  if (source) return source
  return 'unknown'
}
|
||||
|
||||
// buckets: UsageBucket[] from /api/auth/usage/sources (server-sorted ASC by bucket)
|
||||
// selectedKey: 'web' | 'legacy' | api_key_id | null
|
||||
// totals: SourceTotals (for the "Other (count)" legend label)
|
||||
export default function SourceTimeChart({ buckets = [], selectedKey, totals }) {
|
||||
const { t } = useTranslation('admin')
|
||||
|
||||
// Find the top-N identities by total tokens across the period.
|
||||
const topIds = useMemo(() => {
|
||||
const sums = new Map()
|
||||
for (const b of buckets) {
|
||||
const id = identityFor(b)
|
||||
sums.set(id, (sums.get(id) || 0) + (b.total_tokens || 0))
|
||||
}
|
||||
return [...sums.entries()]
|
||||
.sort((a, b) => b[1] - a[1])
|
||||
.slice(0, TOP_N)
|
||||
.map(([id]) => id)
|
||||
}, [buckets])
|
||||
|
||||
const topSet = useMemo(() => new Set(topIds), [topIds])
|
||||
|
||||
// Resolve a display label for an identity (api_key_id -> snapshotted name, or source name).
|
||||
const labelByIdentity = useMemo(() => {
|
||||
const m = new Map()
|
||||
for (const b of buckets) {
|
||||
const id = identityFor(b)
|
||||
if (m.has(id)) continue
|
||||
if (b.source === 'web') { m.set(id, t('usage.sources.webUI')); continue }
|
||||
if (b.source === 'legacy') { m.set(id, t('usage.sources.legacy')); continue }
|
||||
m.set(id, b.api_key_name || b.api_key_id || id)
|
||||
}
|
||||
return m
|
||||
}, [buckets, t])
|
||||
|
||||
// Build a dense per-bucket row, splitting top-N vs Other.
|
||||
const series = useMemo(() => {
|
||||
const byBucket = new Map()
|
||||
for (const b of buckets) {
|
||||
const id = identityFor(b)
|
||||
const seriesId = topSet.has(id) ? id : '__other__'
|
||||
const row = byBucket.get(b.bucket) || { bucket: b.bucket, total: 0 }
|
||||
row[seriesId] = (row[seriesId] || 0) + (b.total_tokens || 0)
|
||||
row.total += b.total_tokens || 0
|
||||
byBucket.set(b.bucket, row)
|
||||
}
|
||||
return [...byBucket.values()]
|
||||
}, [buckets, topSet])
|
||||
|
||||
const max = useMemo(
|
||||
() => series.reduce((m, r) => Math.max(m, r.total), 0) || 1,
|
||||
[series]
|
||||
)
|
||||
|
||||
const seriesIds = [...topIds, '__other__']
|
||||
const colorOf = (id) =>
|
||||
id === '__other__'
|
||||
? OTHER_COLOR
|
||||
: SERIES_COLORS[topIds.indexOf(id) % SERIES_COLORS.length]
|
||||
|
||||
const labelOfId = (id) => {
|
||||
if (id === '__other__') return null // computed inline (need count)
|
||||
return labelByIdentity.get(id) || id
|
||||
}
|
||||
|
||||
const otherCount = Math.max(0, (totals?.by_key?.length || 0) - TOP_N)
|
||||
|
||||
// SVG geometry: 24px wide per bar (2px gap), 100px tall, viewBox stretches with bar count.
|
||||
const barWidth = 20
|
||||
const barGap = 4
|
||||
const slotWidth = barWidth + barGap
|
||||
const height = 100
|
||||
const width = Math.max(series.length * slotWidth, 200)
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-xs)' }}>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 600, color: 'var(--color-text-primary)' }}>
|
||||
{t('usage.sources.topSources')}
|
||||
</div>
|
||||
|
||||
<svg
|
||||
viewBox={`0 0 ${width} ${height}`}
|
||||
preserveAspectRatio="none"
|
||||
style={{ width: '100%', height: 160, display: 'block' }}
|
||||
aria-hidden
|
||||
>
|
||||
{series.map((row, i) => {
|
||||
let y = height
|
||||
return (
|
||||
<g key={row.bucket} transform={`translate(${i * slotWidth}, 0)`}>
|
||||
{seriesIds.map(id => {
|
||||
const v = row[id] || 0
|
||||
if (!v) return null
|
||||
const h = (v / max) * height
|
||||
y -= h
|
||||
const dim = selectedKey && selectedKey !== id ? 0.25 : 1
|
||||
const title = id === '__other__'
|
||||
? t('usage.sources.other', { count: otherCount })
|
||||
: labelOfId(id)
|
||||
return (
|
||||
<rect
|
||||
key={id}
|
||||
x={barGap / 2} y={y}
|
||||
width={barWidth} height={h}
|
||||
fill={colorOf(id)} opacity={dim}
|
||||
>
|
||||
<title>{`${row.bucket} - ${title}: ${v.toLocaleString()}`}</title>
|
||||
</rect>
|
||||
)
|
||||
})}
|
||||
</g>
|
||||
)
|
||||
})}
|
||||
</svg>
|
||||
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 'var(--spacing-sm)', fontSize: '0.75rem' }}>
|
||||
{seriesIds.map(id => (
|
||||
<span key={id} style={{ display: 'inline-flex', alignItems: 'center', gap: 6 }}>
|
||||
<span style={{ width: 10, height: 10, borderRadius: 2, background: colorOf(id), display: 'inline-block' }} aria-hidden />
|
||||
{id === '__other__'
|
||||
? t('usage.sources.other', { count: otherCount })
|
||||
: labelOfId(id)}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
176
core/http/react-ui/src/pages/Usage/SourcesTab.jsx
Normal file
176
core/http/react-ui/src/pages/Usage/SourcesTab.jsx
Normal file
@@ -0,0 +1,176 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { usageApi, apiKeysApi } from '../../utils/api'
|
||||
import { useAuth } from '../../context/AuthContext'
|
||||
import LoadingSpinner from '../../components/LoadingSpinner'
|
||||
import SourceMixRibbon from './SourceMixRibbon'
|
||||
import SourcesTable from './SourcesTable'
|
||||
import SourceTimeChart from './SourceTimeChart'
|
||||
|
||||
// Fallback payload used when the usage request resolves to a falsy value, and
// as the source of per-field defaults when a response omits `totals` or
// `buckets`.
const EMPTY_DATA = {
|
||||
buckets: [],
|
||||
totals: { by_source: {}, by_key: [], grand_total: { tokens: 0, requests: 0 } },
|
||||
truncated: false,
|
||||
}
|
||||
|
||||
// Turns the current selection into text for the filter chip.
//   - no selection           -> ''
//   - 'web' / 'legacy'       -> the translated source-class name
//   - anything else          -> the name of the matching entry in
//                               totals.by_key, or the raw id when there is no
//                               match or the match has no name
function labelForSelected(totals, selectedKey, t) {
  if (!selectedKey) return ''

  switch (selectedKey) {
    case 'web':
      return t('usage.sources.webUI')
    case 'legacy':
      return t('usage.sources.legacy')
    default: {
      const keys = totals?.by_key || []
      const match = keys.find((entry) => entry.api_key_id === selectedKey)
      return match?.api_key_name || selectedKey
    }
  }
}
|
||||
|
||||
// SourcesTab fetches and renders per-source / per-API-key usage breakdown.
|
||||
// Task 10 replaces the raw JSON / list placeholders with SourceMixRibbon and
|
||||
// SourcesTable. Task 11 will add the time chart and drill-in chip.
|
||||
export default function SourcesTab({ period, adminUserId }) {
|
||||
const { t } = useTranslation('admin')
|
||||
const { isAdmin } = useAuth()
|
||||
|
||||
const [data, setData] = useState(EMPTY_DATA)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [error, setError] = useState(null)
|
||||
|
||||
const [selectedKey, setSelectedKey] = useState(null)
|
||||
const [search, setSearch] = useState('')
|
||||
const [sortKey, setSortKey] = useState('tokens')
|
||||
|
||||
// Pull the current set of API key ids so the table can mark unknown keys as
|
||||
// revoked. null = "don't know yet" so the table won't dim live keys during
|
||||
// the fetch or after a failure.
|
||||
const [existingKeyIds, setExistingKeyIds] = useState(null)
|
||||
useEffect(() => {
|
||||
apiKeysApi
|
||||
.list()
|
||||
.then((resp) => {
|
||||
const list = Array.isArray(resp) ? resp : (resp?.keys || [])
|
||||
setExistingKeyIds(new Set(list.map((k) => k.id)))
|
||||
})
|
||||
.catch(() => { /* leave existingKeyIds null so revoked detection is skipped */ })
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
setLoading(true)
|
||||
setError(null)
|
||||
const p = isAdmin
|
||||
? usageApi.getAdminSources(period, adminUserId)
|
||||
: usageApi.getMySources(period)
|
||||
p
|
||||
.then((d) => { if (!cancelled) setData(d || EMPTY_DATA) })
|
||||
.catch((e) => { if (!cancelled) setError(e) })
|
||||
.finally(() => { if (!cancelled) setLoading(false) })
|
||||
return () => { cancelled = true }
|
||||
}, [isAdmin, period, adminUserId])
|
||||
|
||||
const totals = data.totals || EMPTY_DATA.totals
|
||||
const buckets = data.buckets || EMPTY_DATA.buckets
|
||||
const grandT = totals.grand_total || { tokens: 0, requests: 0 }
|
||||
const truncated = data.truncated || false
|
||||
|
||||
const isEmpty = !loading && (grandT.tokens || 0) === 0 && (grandT.requests || 0) === 0
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', padding: 'var(--spacing-xl)' }}>
|
||||
<LoadingSpinner size="lg" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="empty-state">
|
||||
<div className="empty-state-icon"><i className="fas fa-triangle-exclamation" /></div>
|
||||
<h2 className="empty-state-title">Failed to load</h2>
|
||||
<p className="empty-state-text">{String(error.message || error)}</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (isEmpty) {
|
||||
return (
|
||||
<div className="empty-state">
|
||||
<div className="empty-state-icon"><i className="fas fa-key" /></div>
|
||||
<h2 className="empty-state-title">{t('usage.sources.noTrafficShort')}</h2>
|
||||
<p className="empty-state-text">{t('usage.sources.noKeysYet')}</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-md)' }}>
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)' }}>
|
||||
<SourceMixRibbon
|
||||
bySource={totals.by_source}
|
||||
keyCount={(totals.by_key || []).length}
|
||||
onSelectSourceClass={(cls) => setSelectedKey(cls)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{selectedKey && (
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)' }}>
|
||||
<span
|
||||
style={{
|
||||
display: 'inline-flex',
|
||||
alignItems: 'center',
|
||||
gap: 'var(--spacing-xs)',
|
||||
padding: 'calc(var(--spacing-xs) / 2) var(--spacing-sm)',
|
||||
background: 'var(--color-bg-secondary)',
|
||||
color: 'var(--color-text-primary)',
|
||||
fontSize: '0.75rem',
|
||||
borderRadius: 'var(--radius-sm)',
|
||||
border: '1px solid var(--color-border-subtle)',
|
||||
}}
|
||||
>
|
||||
<i className="fas fa-filter" style={{ fontSize: '0.6875rem', color: 'var(--color-text-muted)' }} aria-hidden />
|
||||
{t('usage.sources.filteredTo', { name: labelForSelected(totals, selectedKey, t) })}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setSelectedKey(null)}
|
||||
aria-label={t('usage.sources.clearFilter')}
|
||||
style={{
|
||||
appearance: 'none',
|
||||
background: 'transparent',
|
||||
border: 'none',
|
||||
color: 'var(--color-text-muted)',
|
||||
cursor: 'pointer',
|
||||
padding: 0,
|
||||
fontSize: '0.875rem',
|
||||
lineHeight: 1,
|
||||
}}
|
||||
>
|
||||
<i className="fas fa-xmark" />
|
||||
</button>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)' }}>
|
||||
<SourceTimeChart buckets={buckets} selectedKey={selectedKey} totals={totals} />
|
||||
</div>
|
||||
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)' }}>
|
||||
<SourcesTable
|
||||
totals={totals}
|
||||
selectedKey={selectedKey}
|
||||
onSelectKey={setSelectedKey}
|
||||
search={search}
|
||||
setSearch={setSearch}
|
||||
sortKey={sortKey}
|
||||
setSortKey={setSortKey}
|
||||
existingKeyIds={existingKeyIds}
|
||||
showUserColumn={isAdmin}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{truncated && (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-warning)' }}>
|
||||
{t('usage.sources.truncatedWarning')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
245
core/http/react-ui/src/pages/Usage/SourcesTable.jsx
Normal file
245
core/http/react-ui/src/pages/Usage/SourcesTable.jsx
Normal file
@@ -0,0 +1,245 @@
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
// Comparators for the sortable columns, keyed by sort column id. Numeric and
// date columns sort descending (largest / most recent first); text columns
// sort ascending by locale order. Missing values count as 0 or ''.
const SORT_FNS = {
  tokens: (a, b) => (b.tokens || 0) - (a.tokens || 0),
  requests: (a, b) => (b.requests || 0) - (a.requests || 0),
  // Missing or unparseable timestamps are treated as epoch 0 so they sort
  // last and the comparator never returns NaN (which would make the
  // resulting order inconsistent).
  last_used: (a, b) => {
    const millis = (row) => {
      const ms = new Date(row.last_used || 0).getTime()
      return Number.isNaN(ms) ? 0 : ms
    }
    return millis(b) - millis(a)
  },
  name: (a, b) => (a.name || '').localeCompare(b.name || ''),
  user: (a, b) => (a.userName || '').localeCompare(b.userName || ''),
}
|
||||
|
||||
// Formats a token count compactly: '0' for zero/missing, the plain number
// below one thousand, one decimal with a 'k' suffix for thousands and one
// decimal with an 'M' suffix for millions.
function formatTokens(n) {
  if (!n) return '0'
  if (n < 1_000) return String(n)

  const thousands = (n / 1_000).toFixed(1)
  // Values just under a million round up to "1000.0" thousands; show those
  // as "1.0M" instead of "1000.0k".
  if (n < 1_000_000 && Number(thousands) < 1_000) return thousands + 'k'
  return (n / 1_000_000).toFixed(1) + 'M'
}
|
||||
|
||||
// Formats a timestamp as a short "time ago" string relative to now:
// 'just now' under a minute, then whole minutes, hours or days elapsed.
// Returns '-' when the value is missing, unparseable, or not after the epoch.
function formatRelative(iso) {
  if (!iso) return '-'
  const t = new Date(iso).getTime()
  if (Number.isNaN(t) || t <= 0) return '-'
  const diff = Date.now() - t
  if (diff < 60_000) return 'just now'
  // Count completed units (floor) so a value never reads as the next unit up,
  // e.g. 59.6 minutes is "59m ago", not "60m ago".
  if (diff < 3_600_000) return Math.floor(diff / 60_000) + 'm ago'
  if (diff < 86_400_000) return Math.floor(diff / 3_600_000) + 'h ago'
  return Math.floor(diff / 86_400_000) + 'd ago'
}
|
||||
|
||||
// SourcesTable is the searchable, sortable list of key totals plus pseudo-rows
|
||||
// for the web UI and legacy (unkeyed) source classes. Clicking a row selects
|
||||
// it; the parent decides what to do with the selection (the drill-in panel
|
||||
// will be wired in Task 11).
|
||||
//
|
||||
// Props:
|
||||
// totals: SourceTotals payload (from /api/auth/usage/sources)
|
||||
// selectedKey: currently-selected row id (api_key_id | 'web' | 'legacy' | null)
|
||||
// onSelectKey: (id|null) => void
|
||||
// search / setSearch: free-text filter state lifted to the parent
|
||||
// sortKey / setSortKey: sort column state lifted to the parent
|
||||
// existingKeyIds: Set<string> of current (non-revoked) api key ids, or null
|
||||
// when the parent hasn't yet learned which keys exist. Null suppresses the
|
||||
// revoked badge entirely so live keys aren't dimmed during the fetch or
|
||||
// after a failure.
|
||||
// showUserColumn: render the User column. Admin views set this true so the
|
||||
// reader can attribute each key (and each Web UI row) to its owner.
|
||||
export default function SourcesTable({
|
||||
totals,
|
||||
selectedKey,
|
||||
onSelectKey,
|
||||
search,
|
||||
setSearch,
|
||||
sortKey,
|
||||
setSortKey,
|
||||
existingKeyIds = null,
|
||||
showUserColumn = false,
|
||||
}) {
|
||||
const { t } = useTranslation('admin')
|
||||
|
||||
const rows = useMemo(() => {
|
||||
const named = (totals?.by_key || []).map((k) => ({
|
||||
kind: 'apikey',
|
||||
id: k.api_key_id,
|
||||
name: k.api_key_name || k.api_key_id,
|
||||
userID: k.user_id || '',
|
||||
userName: k.user_name || '',
|
||||
prefix: '',
|
||||
tokens: k.tokens,
|
||||
requests: k.requests,
|
||||
last_used: k.last_used,
|
||||
revoked: existingKeyIds != null && !existingKeyIds.has(k.api_key_id),
|
||||
}))
|
||||
|
||||
// Pseudo-rows for sources that don't have a named key identity.
|
||||
// In admin view (showUserColumn=true), prefer the per-user breakdown
|
||||
// from totals.by_user_source so each user's Web UI / legacy traffic
|
||||
// gets its own row. Otherwise fall back to the global by_source aggregate.
|
||||
let unkeyed = []
|
||||
if (showUserColumn && Array.isArray(totals?.by_user_source) && totals.by_user_source.length > 0) {
|
||||
unkeyed = totals.by_user_source.map((r) => ({
|
||||
kind: r.source,
|
||||
id: r.source + ':' + (r.user_id || ''),
|
||||
name: r.source === 'legacy' ? t('usage.sources.legacy') : t('usage.sources.webUI'),
|
||||
userID: r.user_id || '',
|
||||
userName: r.user_name || '',
|
||||
prefix: '-',
|
||||
tokens: r.tokens,
|
||||
requests: r.requests,
|
||||
}))
|
||||
} else {
|
||||
if (totals?.by_source?.web) {
|
||||
unkeyed.push({
|
||||
kind: 'web',
|
||||
id: 'web',
|
||||
name: t('usage.sources.webUI'),
|
||||
userID: '',
|
||||
userName: '',
|
||||
prefix: '-',
|
||||
tokens: totals.by_source.web.tokens,
|
||||
requests: totals.by_source.web.requests,
|
||||
})
|
||||
}
|
||||
if (totals?.by_source?.legacy) {
|
||||
unkeyed.push({
|
||||
kind: 'legacy',
|
||||
id: 'legacy',
|
||||
name: t('usage.sources.legacy'),
|
||||
userID: '',
|
||||
userName: '',
|
||||
prefix: '-',
|
||||
tokens: totals.by_source.legacy.tokens,
|
||||
requests: totals.by_source.legacy.requests,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return [...named, ...unkeyed]
|
||||
}, [totals, existingKeyIds, showUserColumn, t])
|
||||
|
||||
const filtered = useMemo(() => {
|
||||
const q = (search || '').trim().toLowerCase()
|
||||
const list = q
|
||||
? rows.filter((r) =>
|
||||
(r.name || '').toLowerCase().includes(q) ||
|
||||
(r.prefix || '').toLowerCase().includes(q) ||
|
||||
(r.userName || '').toLowerCase().includes(q) ||
|
||||
(r.userID || '').toLowerCase().includes(q)
|
||||
)
|
||||
: rows
|
||||
return [...list].sort(SORT_FNS[sortKey] || SORT_FNS.tokens)
|
||||
}, [rows, search, sortKey])
|
||||
|
||||
const iconFor = (kind) =>
|
||||
kind === 'apikey' ? 'fas fa-key' : kind === 'web' ? 'fas fa-globe' : 'fas fa-gear'
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-sm)' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)', flexWrap: 'wrap' }}>
|
||||
<input
|
||||
type="search"
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
placeholder={t('usage.sources.searchPlaceholder')}
|
||||
aria-label={t('usage.sources.searchPlaceholder')}
|
||||
style={{
|
||||
flex: '1 1 12rem',
|
||||
minWidth: 160,
|
||||
padding: 'var(--spacing-xs) var(--spacing-sm)',
|
||||
border: '1px solid var(--color-border-subtle)',
|
||||
borderRadius: 'var(--radius-sm)',
|
||||
background: 'var(--color-bg-primary)',
|
||||
color: 'var(--color-text-primary)',
|
||||
}}
|
||||
/>
|
||||
<label style={{ display: 'inline-flex', alignItems: 'center', gap: 6, fontSize: '0.75rem' }}>
|
||||
{t('usage.sources.sortBy')}:
|
||||
<select
|
||||
value={sortKey}
|
||||
onChange={(e) => setSortKey(e.target.value)}
|
||||
style={{
|
||||
padding: 'calc(var(--spacing-xs) / 2) var(--spacing-xs)',
|
||||
border: '1px solid var(--color-border-subtle)',
|
||||
borderRadius: 'var(--radius-sm)',
|
||||
background: 'var(--color-bg-primary)',
|
||||
color: 'var(--color-text-primary)',
|
||||
}}
|
||||
>
|
||||
<option value="tokens">{t('usage.sources.sortTokens')}</option>
|
||||
<option value="requests">{t('usage.sources.sortRequests')}</option>
|
||||
<option value="last_used">{t('usage.sources.sortLastUsed')}</option>
|
||||
<option value="name">{t('usage.sources.sortName')}</option>
|
||||
{showUserColumn && <option value="user">{t('usage.sources.sortUser')}</option>}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>{t('usage.sources.sortName')}</th>
|
||||
{showUserColumn && <th style={{ width: 180 }}>{t('usage.sources.sortUser')}</th>}
|
||||
<th style={{ width: 110 }}>Prefix</th>
|
||||
<th style={{ width: 100, textAlign: 'right' }}>{t('usage.sources.sortRequests')}</th>
|
||||
<th style={{ width: 100, textAlign: 'right' }}>{t('usage.sources.sortTokens')}</th>
|
||||
<th style={{ width: 120, textAlign: 'right' }}>{t('usage.sources.sortLastUsed')}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filtered.map((r) => {
|
||||
const isSel = selectedKey === r.id
|
||||
return (
|
||||
<tr
|
||||
key={r.id}
|
||||
onClick={() => onSelectKey?.(isSel ? null : r.id)}
|
||||
style={{
|
||||
cursor: 'pointer',
|
||||
background: isSel ? 'var(--color-bg-secondary)' : undefined,
|
||||
opacity: r.revoked ? 0.5 : 1,
|
||||
}}
|
||||
>
|
||||
<td>
|
||||
<span style={{ display: 'inline-flex', alignItems: 'center', gap: 8 }}>
|
||||
<i
|
||||
className={iconFor(r.kind)}
|
||||
style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem' }}
|
||||
/>
|
||||
<span>{r.name}</span>
|
||||
{r.revoked && (
|
||||
<span
|
||||
style={{
|
||||
fontSize: '0.6875rem',
|
||||
textTransform: 'uppercase',
|
||||
color: 'var(--color-text-muted)',
|
||||
}}
|
||||
>
|
||||
({t('usage.sources.revoked')})
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
</td>
|
||||
{showUserColumn && (
|
||||
<td style={{ color: 'var(--color-text-secondary)', fontSize: '0.8125rem' }}>
|
||||
{r.userName || r.userID || '-'}
|
||||
</td>
|
||||
)}
|
||||
<td style={{ color: 'var(--color-text-muted)', fontSize: '0.75rem' }}>{r.prefix || '-'}</td>
|
||||
<td style={{ textAlign: 'right', fontFamily: 'var(--font-mono)' }}>
|
||||
{Number(r.requests || 0).toLocaleString()}
|
||||
</td>
|
||||
<td style={{ textAlign: 'right', fontFamily: 'var(--font-mono)' }}>
|
||||
{formatTokens(r.tokens || 0)}
|
||||
</td>
|
||||
<td style={{ textAlign: 'right', fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
{formatRelative(r.last_used)}
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
8
core/http/react-ui/src/utils/api.js
vendored
8
core/http/react-ui/src/utils/api.js
vendored
@@ -422,6 +422,14 @@ export const usageApi = {
|
||||
if (userId) url += `&user_id=${encodeURIComponent(userId)}`
|
||||
return fetchJSON(url)
|
||||
},
|
||||
getMySources: (period) =>
|
||||
fetchJSON(`/api/auth/usage/sources?period=${period || 'month'}`),
|
||||
getAdminSources: (period, userId, apiKeyId) => {
|
||||
let url = `/api/auth/admin/usage/sources?period=${period || 'month'}`
|
||||
if (userId) url += `&user_id=${encodeURIComponent(userId)}`
|
||||
if (apiKeyId) url += `&api_key_id=${encodeURIComponent(apiKeyId)}`
|
||||
return fetchJSON(url)
|
||||
},
|
||||
getMyQuotas: () => fetchJSON('/api/auth/quota'),
|
||||
}
|
||||
|
||||
|
||||
@@ -789,6 +789,30 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
})
|
||||
})
|
||||
|
||||
// GET /api/auth/usage/sources - caller's per-source breakdown (no legacy)
|
||||
e.GET("/api/auth/usage/sources", func(c echo.Context) error {
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
|
||||
period := c.QueryParam("period")
|
||||
if period == "" {
|
||||
period = "month"
|
||||
}
|
||||
|
||||
buckets, totals, err := auth.GetUserUsageBySource(db, user.ID, period)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"buckets": buckets,
|
||||
"totals": totals,
|
||||
"truncated": false,
|
||||
})
|
||||
})
|
||||
|
||||
// Admin endpoints
|
||||
adminMw := auth.RequireAdmin()
|
||||
|
||||
@@ -1104,6 +1128,27 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
})
|
||||
}, adminMw)
|
||||
|
||||
// GET /api/auth/admin/usage/sources - all users' per-source breakdown (admin only)
|
||||
e.GET("/api/auth/admin/usage/sources", func(c echo.Context) error {
|
||||
period := c.QueryParam("period")
|
||||
if period == "" {
|
||||
period = "month"
|
||||
}
|
||||
userID := c.QueryParam("user_id")
|
||||
apiKeyID := c.QueryParam("api_key_id")
|
||||
|
||||
buckets, totals, truncated, err := auth.GetAllUsageBySource(db, period, userID, apiKeyID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"buckets": buckets,
|
||||
"totals": totals,
|
||||
"truncated": truncated,
|
||||
})
|
||||
}, adminMw)
|
||||
|
||||
// --- Invite management endpoints ---
|
||||
|
||||
// POST /api/auth/admin/invites - create invite (admin only)
|
||||
|
||||
@@ -286,6 +286,45 @@ func newTestAuthApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo
|
||||
return c.JSON(http.StatusOK, map[string]string{"message": "user deleted"})
|
||||
}, adminMw)
|
||||
|
||||
// Mirror of production handler in routes/auth.go GET /api/auth/usage/sources.
|
||||
// Keep this body in sync with the real handler; this test app cannot call
|
||||
// RegisterAuthRoutes because it needs a *application.Application.
|
||||
e.GET("/api/auth/usage/sources", func(c echo.Context) error {
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
period := c.QueryParam("period")
|
||||
if period == "" {
|
||||
period = "month"
|
||||
}
|
||||
buckets, totals, err := auth.GetUserUsageBySource(db, user.ID, period)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"buckets": buckets, "totals": totals, "truncated": false,
|
||||
})
|
||||
})
|
||||
|
||||
// Mirror of production handler in routes/auth.go GET /api/auth/admin/usage/sources.
|
||||
// Keep this body in sync with the real handler.
|
||||
e.GET("/api/auth/admin/usage/sources", func(c echo.Context) error {
|
||||
period := c.QueryParam("period")
|
||||
if period == "" {
|
||||
period = "month"
|
||||
}
|
||||
userID := c.QueryParam("user_id")
|
||||
apiKeyID := c.QueryParam("api_key_id")
|
||||
buckets, totals, truncated, err := auth.GetAllUsageBySource(db, period, userID, apiKeyID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"buckets": buckets, "totals": totals, "truncated": truncated,
|
||||
})
|
||||
}, adminMw)
|
||||
|
||||
// Regular API endpoint for testing
|
||||
e.POST("/v1/chat/completions", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
@@ -931,4 +970,110 @@ var _ = Describe("Auth Routes", Label("auth"), func() {
|
||||
Expect(providers).To(ContainElement(auth.ProviderGitHub))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GET /api/auth/usage/sources", func() {
|
||||
It("returns only the caller's data, never legacy", func() {
|
||||
app := newTestAuthApp(db, appConfig)
|
||||
|
||||
alice := createRouteTestUser(db, "alice@example.com", auth.RoleUser)
|
||||
aliceToken, err := auth.CreateSession(db, alice.ID, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
keyID := "k-alice"
|
||||
now := time.Now()
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: alice.ID, Source: auth.UsageSourceAPIKey,
|
||||
APIKeyID: &keyID, APIKeyName: "alice-key",
|
||||
Model: "gpt-4", TotalTokens: 100, CreatedAt: now,
|
||||
})).To(Succeed())
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: alice.ID, Source: auth.UsageSourceWeb,
|
||||
Model: "gpt-4", TotalTokens: 50, CreatedAt: now,
|
||||
})).To(Succeed())
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "legacy-api-key", Source: auth.UsageSourceLegacy,
|
||||
Model: "gpt-4", TotalTokens: 30, CreatedAt: now,
|
||||
})).To(Succeed())
|
||||
|
||||
rec := doAuthRequest(app, http.MethodGet, "/api/auth/usage/sources?period=month", nil, withSession(aliceToken))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp struct {
|
||||
Buckets []auth.UsageBucket `json:"buckets"`
|
||||
Totals auth.SourceTotals `json:"totals"`
|
||||
Truncated bool `json:"truncated"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
_, hasLegacy := resp.Totals.BySource[auth.UsageSourceLegacy]
|
||||
Expect(hasLegacy).To(BeFalse())
|
||||
Expect(resp.Totals.GrandTotal.Tokens).To(Equal(int64(150)))
|
||||
Expect(resp.Truncated).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns 401 when unauthenticated", func() {
|
||||
app := newTestAuthApp(db, appConfig)
|
||||
|
||||
// Without a session cookie or bearer token, the global auth middleware
|
||||
// should refuse the request before our handler runs.
|
||||
rec := doAuthRequest(app, http.MethodGet, "/api/auth/usage/sources?period=month", nil)
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GET /api/auth/admin/usage/sources", func() {
|
||||
It("returns 403 for non-admin", func() {
|
||||
app := newTestAuthApp(db, appConfig)
|
||||
|
||||
alice := createRouteTestUser(db, "alice@example.com", auth.RoleUser)
|
||||
aliceToken, _ := auth.CreateSession(db, alice.ID, "")
|
||||
|
||||
rec := doAuthRequest(app, http.MethodGet, "/api/auth/admin/usage/sources?period=month", nil, withSession(aliceToken))
|
||||
Expect(rec.Code).To(Equal(http.StatusForbidden))
|
||||
})
|
||||
|
||||
It("returns legacy bucket for admin and applies api_key_id filter", func() {
|
||||
app := newTestAuthApp(db, appConfig)
|
||||
|
||||
admin := createRouteTestUser(db, "admin@example.com", auth.RoleAdmin)
|
||||
adminToken, _ := auth.CreateSession(db, admin.ID, "")
|
||||
|
||||
k1 := "k1"
|
||||
k2 := "k2"
|
||||
now := time.Now()
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "alice", Source: auth.UsageSourceAPIKey, APIKeyID: &k1, APIKeyName: "ci", Model: "gpt-4", TotalTokens: 10, CreatedAt: now})).To(Succeed())
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "alice", Source: auth.UsageSourceAPIKey, APIKeyID: &k2, APIKeyName: "lap", Model: "gpt-4", TotalTokens: 20, CreatedAt: now})).To(Succeed())
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "legacy-api-key", Source: auth.UsageSourceLegacy, Model: "gpt-4", TotalTokens: 5, CreatedAt: now})).To(Succeed())
|
||||
|
||||
rec := doAuthRequest(app, http.MethodGet,
|
||||
"/api/auth/admin/usage/sources?period=month&api_key_id=k2", nil, withSession(adminToken))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp struct {
|
||||
Totals auth.SourceTotals `json:"totals"`
|
||||
Truncated bool `json:"truncated"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Totals.GrandTotal.Tokens).To(Equal(int64(20)))
|
||||
})
|
||||
|
||||
It("includes legacy in by_source for admin with no filter", func() {
|
||||
app := newTestAuthApp(db, appConfig)
|
||||
|
||||
admin := createRouteTestUser(db, "admin@example.com", auth.RoleAdmin)
|
||||
adminToken, _ := auth.CreateSession(db, admin.ID, "")
|
||||
|
||||
now := time.Now()
|
||||
Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "legacy-api-key", Source: auth.UsageSourceLegacy, Model: "gpt-4", TotalTokens: 7, CreatedAt: now})).To(Succeed())
|
||||
|
||||
rec := doAuthRequest(app, http.MethodGet, "/api/auth/admin/usage/sources?period=month", nil, withSession(adminToken))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp struct {
|
||||
Totals auth.SourceTotals `json:"totals"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Totals.BySource).To(HaveKey(auth.UsageSourceLegacy))
|
||||
Expect(resp.Totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(7)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -53,7 +55,12 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
|
||||
// RegisterNodeAdminRoutes registers /api/nodes/ endpoints used by admins
|
||||
// (list, get, get models, drain, delete, approve, backend management). Protected by admin middleware.
|
||||
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
|
||||
//
|
||||
// galleryService/opcache/appConfig are threaded in for the async node-scoped
|
||||
// backend install path (POST /:id/backends/install). That handler enqueues a
|
||||
// ManagementOp on the gallery channel rather than blocking on a NATS reply, so
|
||||
// the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes.
|
||||
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -78,7 +85,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
|
||||
// Backend management on workers
|
||||
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))
|
||||
admin.POST("/:id/backends/install", localai.InstallBackendOnNodeEndpoint(unloader))
|
||||
admin.POST("/:id/backends/install", localai.InstallBackendOnNodeEndpoint(unloader, galleryService, opcache, appConfig))
|
||||
admin.POST("/:id/backends/delete", localai.DeleteBackendOnNodeEndpoint(unloader))
|
||||
|
||||
// Model management on workers
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -57,7 +58,6 @@ var usecaseFilters = map[string]config.ModelConfigUsecase{
|
||||
config.UsecaseRealtimeAudio: config.FLAG_REALTIME_AUDIO,
|
||||
}
|
||||
|
||||
|
||||
// extractHFRepo tries to find a HuggingFace repo ID from model overrides or URLs.
|
||||
func extractHFRepo(overrides map[string]any, urls []string) string {
|
||||
if overrides != nil {
|
||||
@@ -214,6 +214,17 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
}
|
||||
}
|
||||
|
||||
// Node-scoped backend ops (from /api/nodes/:id/backends/install)
|
||||
// carry the nodeID inside the opcache key as "node:<nodeID>:<backend>".
|
||||
// Pull it back out so the operations panel can label which node the
|
||||
// install is targeting, and so the display name is just the backend
|
||||
// slug instead of the full prefixed key.
|
||||
scopedNodeID := ""
|
||||
if nodeID, backend, ok := galleryop.ParseNodeScopedKey(galleryID); ok {
|
||||
scopedNodeID = nodeID
|
||||
galleryID = backend
|
||||
}
|
||||
|
||||
// Extract display name (remove repo prefix if exists)
|
||||
displayName := galleryID
|
||||
if strings.Contains(galleryID, "@") {
|
||||
@@ -237,9 +248,53 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
"cancellable": isCancellable,
|
||||
"message": message,
|
||||
}
|
||||
// Only attach nodeID when this op was node-scoped: an empty string
|
||||
// would mislead the UI into rendering a node attribution that never
|
||||
// existed in the first place.
|
||||
if scopedNodeID != "" {
|
||||
opData["nodeID"] = scopedNodeID
|
||||
}
|
||||
if status != nil && status.Error != nil {
|
||||
opData["error"] = status.Error.Error()
|
||||
}
|
||||
// Expose the per-node breakdown when the Phase 4 progress sink
|
||||
// has populated OpStatus.Nodes (distributed backend installs).
|
||||
// We sort by node_name for stable UI rendering across polls;
|
||||
// the underlying slice is order-dependent on UpdateNodeProgress
|
||||
// arrival order, which the UI must not depend on. Single-node
|
||||
// ops and model installs leave Nodes empty so this block emits
|
||||
// no key, preserving the legacy payload shape.
|
||||
if status != nil && len(status.Nodes) > 0 {
|
||||
nodes := make([]map[string]any, 0, len(status.Nodes))
|
||||
for _, n := range status.Nodes {
|
||||
entry := map[string]any{
|
||||
"node_id": n.NodeID,
|
||||
"node_name": n.NodeName,
|
||||
"status": n.Status,
|
||||
"percentage": n.Percentage,
|
||||
}
|
||||
if n.FileName != "" {
|
||||
entry["file_name"] = n.FileName
|
||||
}
|
||||
if n.Current != "" {
|
||||
entry["current"] = n.Current
|
||||
}
|
||||
if n.Total != "" {
|
||||
entry["total"] = n.Total
|
||||
}
|
||||
if n.Phase != "" {
|
||||
entry["phase"] = n.Phase
|
||||
}
|
||||
if n.Error != "" {
|
||||
entry["error"] = n.Error
|
||||
}
|
||||
nodes = append(nodes, entry)
|
||||
}
|
||||
sort.SliceStable(nodes, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", nodes[i]["node_name"]) < fmt.Sprintf("%v", nodes[j]["node_name"])
|
||||
})
|
||||
opData["nodes"] = nodes
|
||||
}
|
||||
operations = append(operations, opData)
|
||||
}
|
||||
|
||||
@@ -540,11 +595,11 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
NodeStatus string `json:"node_status"`
|
||||
}
|
||||
type modelCapability struct {
|
||||
ID string `json:"id"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
Backend string `json:"backend"`
|
||||
Disabled bool `json:"disabled"`
|
||||
Pinned bool `json:"pinned"`
|
||||
ID string `json:"id"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
Backend string `json:"backend"`
|
||||
Disabled bool `json:"disabled"`
|
||||
Pinned bool `json:"pinned"`
|
||||
// LoadedOn is populated only when the node registry is active
|
||||
// (distributed mode). Lets the UI show "loaded on worker-1" without
|
||||
// the operator having to expand every node manually. An empty slice
|
||||
@@ -1142,17 +1197,17 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
}
|
||||
|
||||
return c.JSON(200, map[string]any{
|
||||
"backends": backendsJSON,
|
||||
"repositories": appConfig.BackendGalleries,
|
||||
"allTags": tags,
|
||||
"processingBackends": processingBackendsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableBackends": totalBackends,
|
||||
"installedBackends": installedBackendsCount,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
"nextPage": nextPage,
|
||||
"backends": backendsJSON,
|
||||
"repositories": appConfig.BackendGalleries,
|
||||
"allTags": tags,
|
||||
"processingBackends": processingBackendsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableBackends": totalBackends,
|
||||
"installedBackends": installedBackendsCount,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
"nextPage": nextPage,
|
||||
"systemCapability": detectedCapability,
|
||||
"preferDevelopmentBackends": appConfig.PreferDevelopmentBackends,
|
||||
})
|
||||
@@ -1582,4 +1637,3 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
app.DELETE("/api/branding/asset/:kind", localai.DeleteBrandingAssetEndpoint(appConfig), adminMiddleware)
|
||||
|
||||
}
|
||||
|
||||
|
||||
155
core/http/routes/ui_api_operations_test.go
Normal file
155
core/http/routes/ui_api_operations_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package routes_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
)
|
||||
|
||||
// These specs guard the contract between the opcache (which stores
|
||||
// node-scoped backend installs under a "node:<nodeID>:<backend>" key) and the
|
||||
// /api/operations response surface the React UI polls. Without nodeID
|
||||
// extraction the panel would show the raw prefixed key and have no way to
|
||||
// label which worker an install is targeting.
|
||||
var _ = Describe("/api/operations with node-scoped backend ops", func() {
|
||||
// We pass a zero-value *application.Application because the handler's
|
||||
// distributed-services branch guards on a nil check on the returned
|
||||
// *DistributedServices, which is nil for a fresh Application{}.
|
||||
noopMw := func(next echo.HandlerFunc) echo.HandlerFunc { return next }
|
||||
|
||||
It("emits nodeID and the un-prefixed backend name for keys built by NodeScopedKey", func() {
|
||||
appCfg := &config.ApplicationConfig{}
|
||||
galleryService := galleryop.NewGalleryService(appCfg, nil)
|
||||
opcache := galleryop.NewOpCache(galleryService)
|
||||
|
||||
key := galleryop.NodeScopedKey("worker-7", "llama-cpp")
|
||||
opcache.SetBackend(key, "job-uuid-123")
|
||||
|
||||
e := echo.New()
|
||||
routes.RegisterUIAPIRoutes(e, nil, nil, appCfg, galleryService, opcache, &application.Application{}, noopMw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/operations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// The handler wraps operations in {"operations": [...]}.
|
||||
var envelope struct {
|
||||
Operations []map[string]any `json:"operations"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &envelope)).To(Succeed())
|
||||
|
||||
var found map[string]any
|
||||
for _, op := range envelope.Operations {
|
||||
if op["jobID"] == "job-uuid-123" {
|
||||
found = op
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).ToNot(BeNil(), "node-scoped op should appear in /api/operations")
|
||||
Expect(found["nodeID"]).To(Equal("worker-7"))
|
||||
Expect(found["name"]).To(Equal("llama-cpp"))
|
||||
Expect(found["isBackend"]).To(Equal(true))
|
||||
})
|
||||
|
||||
It("surfaces per-node OpStatus entries on /api/operations", func() {
|
||||
appCfg := &config.ApplicationConfig{}
|
||||
galleryService := galleryop.NewGalleryService(appCfg, nil)
|
||||
opcache := galleryop.NewOpCache(galleryService)
|
||||
|
||||
jobID := "test-op-nodes-1"
|
||||
// Register a backend op so the handler treats this as a backend
|
||||
// install (no need to consult the gallery during the test).
|
||||
opcache.SetBackend("vllm", jobID)
|
||||
|
||||
// Populate per-node entries via the P4.2 helper. The helper also
|
||||
// allocates an OpStatus under jobID, which the handler will read.
|
||||
galleryService.UpdateNodeProgress(jobID, "node-b", galleryop.NodeProgress{
|
||||
NodeID: "node-b", NodeName: "worker-b", Status: "running_on_worker",
|
||||
})
|
||||
galleryService.UpdateNodeProgress(jobID, "node-a", galleryop.NodeProgress{
|
||||
NodeID: "node-a", NodeName: "worker-a", Status: "downloading", Percentage: 30, FileName: "vllm.tar",
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
routes.RegisterUIAPIRoutes(e, nil, nil, appCfg, galleryService, opcache, &application.Application{}, noopMw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/operations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var envelope struct {
|
||||
Operations []map[string]any `json:"operations"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &envelope)).To(Succeed())
|
||||
|
||||
var found map[string]any
|
||||
for _, op := range envelope.Operations {
|
||||
if op["jobID"] == jobID {
|
||||
found = op
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).ToNot(BeNil(), "operation should appear in /api/operations")
|
||||
nodes, ok := found["nodes"].([]any)
|
||||
Expect(ok).To(BeTrue(), "operation should have a nodes array")
|
||||
Expect(nodes).To(HaveLen(2))
|
||||
|
||||
// Stable sort by node_name: "worker-a" comes before "worker-b"
|
||||
// even though UpdateNodeProgress was called in reverse order.
|
||||
first := nodes[0].(map[string]any)
|
||||
Expect(first["node_name"]).To(Equal("worker-a"))
|
||||
Expect(first["status"]).To(Equal("downloading"))
|
||||
Expect(first["file_name"]).To(Equal("vllm.tar"))
|
||||
Expect(first["percentage"]).To(Equal(30.0))
|
||||
|
||||
second := nodes[1].(map[string]any)
|
||||
Expect(second["node_name"]).To(Equal("worker-b"))
|
||||
Expect(second["status"]).To(Equal("running_on_worker"))
|
||||
})
|
||||
|
||||
It("does not emit nodeID for non-node-scoped backend ops", func() {
	// Identifiers shared between the setup and the assertions below.
	const (
		backendName = "llama-cpp"
		jobID       = "job-uuid-456"
	)

	cfg := &config.ApplicationConfig{}
	svc := galleryop.NewGalleryService(cfg, nil)
	cache := galleryop.NewOpCache(svc)

	// Legacy/global install path: the opcache key is the bare backend name,
	// with no node-scoped prefix in front of it.
	cache.SetBackend(backendName, jobID)

	server := echo.New()
	routes.RegisterUIAPIRoutes(server, nil, nil, cfg, svc, cache, &application.Application{}, noopMw)

	recorder := httptest.NewRecorder()
	server.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/api/operations", nil))
	Expect(recorder.Code).To(Equal(http.StatusOK))

	var payload struct {
		Operations []map[string]any `json:"operations"`
	}
	Expect(json.Unmarshal(recorder.Body.Bytes(), &payload)).To(Succeed())

	// Pick out the operation that belongs to the job registered above.
	var match map[string]any
	for i := range payload.Operations {
		if payload.Operations[i]["jobID"] == jobID {
			match = payload.Operations[i]
			break
		}
	}
	Expect(match).ToNot(BeNil())
	// Critical: bare ops must NOT gain a misleading empty nodeID field.
	Expect(match).ToNot(HaveKey("nodeID"), "non-node-scoped ops must NOT carry a nodeID field")
	Expect(match["name"]).To(Equal(backendName))
})
|
||||
})
|
||||
@@ -91,6 +91,21 @@ func (g *GalleryService) backendHandler(op *ManagementOp[gallery.GalleryBackend,
|
||||
})
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, ErrWorkerStillInstalling) {
|
||||
// Soft failure: at least one worker timed out replying but is
|
||||
// still running the install in the background. Mark the op as
|
||||
// processed with a non-error message so the admin UI shows a
|
||||
// yellow in-progress state rather than red. The reconciler's
|
||||
// next pass will reconcile the actual outcome via backend.list.
|
||||
xlog.Info("worker still installing in background", "backend", op.GalleryElementName, "error", err)
|
||||
g.UpdateStatus(op.ID, &OpStatus{
|
||||
Processed: true,
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
Message: fmt.Sprintf("backend %s: worker still installing in background; reconciler will confirm completion (%v)", op.GalleryElementName, err),
|
||||
Cancellable: false,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
xlog.Error("error installing backend", "error", err, "backend", op.GalleryElementName)
|
||||
if !op.Delete {
|
||||
// If we didn't install the backend, we need to make sure we don't have a leftover directory
|
||||
|
||||
@@ -196,4 +196,60 @@ var _ = Describe("ManagementOp with External Backend", func() {
|
||||
Expect(op.ExternalName).To(Equal("test-backend"))
|
||||
Expect(op.ExternalAlias).To(Equal("test-alias"))
|
||||
})
|
||||
|
||||
Context("TargetNodeID field", func() {
|
||||
It("defaults to empty string", func() {
|
||||
op := galleryop.ManagementOp[string, string]{
|
||||
ExternalURI: "oci://example.com/backend:latest",
|
||||
}
|
||||
Expect(op.TargetNodeID).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("preserves TargetNodeID across a channel send", func() {
|
||||
ch := make(chan galleryop.ManagementOp[string, string], 1)
|
||||
ch <- galleryop.ManagementOp[string, string]{
|
||||
GalleryElementName: "llama-cpp",
|
||||
TargetNodeID: "node-abc-123",
|
||||
}
|
||||
received := <-ch
|
||||
Expect(received.TargetNodeID).To(Equal("node-abc-123"))
|
||||
Expect(received.GalleryElementName).To(Equal("llama-cpp"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("NodeScopedKey", func() {
|
||||
It("builds a unique key per (nodeID, backend) pair", func() {
|
||||
Expect(galleryop.NodeScopedKey("node-a", "llama-cpp")).To(Equal("node:node-a:llama-cpp"))
|
||||
Expect(galleryop.NodeScopedKey("node-b", "llama-cpp")).To(Equal("node:node-b:llama-cpp"))
|
||||
Expect(galleryop.NodeScopedKey("node-a", "vllm")).To(Equal("node:node-a:vllm"))
|
||||
})
|
||||
|
||||
It("handles backend names containing colons", func() {
|
||||
// Gallery IDs sometimes look like "official@llama-cpp"; nodeIDs are UUIDs
|
||||
// without colons, but the backend slug may contain anything. Splitting on
|
||||
// the first colon after the prefix MUST yield the full backend back.
|
||||
key := galleryop.NodeScopedKey("node-1", "official@llama-cpp:v2")
|
||||
node, backend, ok := galleryop.ParseNodeScopedKey(key)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(node).To(Equal("node-1"))
|
||||
Expect(backend).To(Equal("official@llama-cpp:v2"))
|
||||
})
|
||||
|
||||
It("rejects keys without the node prefix", func() {
|
||||
_, _, ok := galleryop.ParseNodeScopedKey("llama-cpp")
|
||||
Expect(ok).To(BeFalse())
|
||||
_, _, ok = galleryop.ParseNodeScopedKey("official@llama-cpp")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects malformed node-prefixed keys", func() {
|
||||
_, _, ok := galleryop.ParseNodeScopedKey("node:only-one-segment")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects keys with an empty nodeID segment", func() {
|
||||
_, _, ok := galleryop.ParseNodeScopedKey("node::llama-cpp")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
13
core/services/galleryop/errors.go
Normal file
13
core/services/galleryop/errors.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package galleryop
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrWorkerStillInstalling indicates a distributed backend install
|
||||
// timed out at the NATS round-trip layer but the worker is most likely
|
||||
// still pulling the OCI image in the background. Producers
|
||||
// (DistributedBackendManager) wrap this when the round-trip times out;
|
||||
// consumers (backendHandler) use errors.Is(err, ErrWorkerStillInstalling)
|
||||
// to surface a yellow "in progress" OpStatus instead of a red error,
|
||||
// leaving the pending_backend_ops row in place for the reconciler to
|
||||
// confirm via backend.list.
|
||||
var ErrWorkerStillInstalling = errors.New("worker did not reply in time; install may still be running in the background")
|
||||
93
core/services/galleryop/node_progress_test.go
Normal file
93
core/services/galleryop/node_progress_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package galleryop_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
)
|
||||
|
||||
var _ = Describe("OpStatus.Nodes", func() {
|
||||
It("defaults to empty on a fresh OpStatus", func() {
|
||||
os := &galleryop.OpStatus{}
|
||||
Expect(os.Nodes).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("JSON round-trips with all NodeProgress fields", func() {
|
||||
os := &galleryop.OpStatus{
|
||||
Nodes: []galleryop.NodeProgress{
|
||||
{
|
||||
NodeID: "node-1",
|
||||
NodeName: "worker-a",
|
||||
Status: "running_on_worker",
|
||||
FileName: "vllm.tar.zst",
|
||||
Current: "412 MB",
|
||||
Total: "2.1 GB",
|
||||
Percentage: 19.6,
|
||||
Phase: "downloading",
|
||||
Error: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(os)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
got := &galleryop.OpStatus{}
|
||||
Expect(json.Unmarshal(raw, got)).To(Succeed())
|
||||
Expect(got.Nodes).To(HaveLen(1))
|
||||
Expect(got.Nodes[0]).To(Equal(os.Nodes[0]))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("GalleryService.UpdateNodeProgress", func() {
|
||||
var svc *galleryop.GalleryService
|
||||
|
||||
BeforeEach(func() {
|
||||
// UpdateNodeProgress + GetStatus only touch the in-memory statuses
|
||||
// map. A zero-value ApplicationConfig is enough to get past the
|
||||
// LocalModelManager / LocalBackendManager constructors.
|
||||
svc = galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
|
||||
})
|
||||
|
||||
It("creates a node entry on first call", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{
|
||||
NodeID: "n1", NodeName: "worker-a", Status: "downloading", Percentage: 12.0,
|
||||
})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st).ToNot(BeNil())
|
||||
Expect(st.Nodes).To(HaveLen(1))
|
||||
Expect(st.Nodes[0].NodeID).To(Equal("n1"))
|
||||
Expect(st.Nodes[0].Percentage).To(Equal(12.0))
|
||||
})
|
||||
|
||||
It("merges subsequent updates into the same NodeID entry, not appending", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: "downloading", Percentage: 12.0})
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: "downloading", Percentage: 48.0, FileName: "vllm.tar"})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Nodes).To(HaveLen(1))
|
||||
Expect(st.Nodes[0].Percentage).To(Equal(48.0))
|
||||
Expect(st.Nodes[0].FileName).To(Equal("vllm.tar"))
|
||||
})
|
||||
|
||||
It("appends a new entry for a different NodeID", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: "downloading", Percentage: 12.0})
|
||||
svc.UpdateNodeProgress("op1", "n2", galleryop.NodeProgress{NodeID: "n2", NodeName: "worker-b", Status: "queued"})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Nodes).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("mirrors the latest tick into the aggregate OpStatus fields", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{
|
||||
NodeID: "n1", NodeName: "worker-a", Status: "downloading",
|
||||
Percentage: 33.0, FileName: "vllm.tar", Current: "330 MB", Total: "1 GB",
|
||||
})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Progress).To(Equal(33.0))
|
||||
Expect(st.FileName).To(Equal("vllm.tar"))
|
||||
Expect(st.DownloadedFileSize).To(Equal("330 MB"))
|
||||
Expect(st.TotalFileSize).To(Equal("1 GB"))
|
||||
})
|
||||
})
|
||||
@@ -2,6 +2,7 @@ package galleryop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
@@ -30,6 +31,12 @@ type ManagementOp[T any, E any] struct {
|
||||
ExternalName string // Custom name for the backend
|
||||
ExternalAlias string // Custom alias for the backend
|
||||
|
||||
// TargetNodeID scopes a backend install/upgrade to a single worker node.
|
||||
// Empty means fan out to every healthy backend node (the previous behavior).
|
||||
// Set by InstallBackendOnNodeEndpoint so an admin can install a hardware-specific
|
||||
// build on one node without touching the rest of the cluster.
|
||||
TargetNodeID string
|
||||
|
||||
// Upgrade is true if this is an upgrade operation (not a fresh install)
|
||||
Upgrade bool
|
||||
}
|
||||
@@ -46,6 +53,30 @@ type OpStatus struct {
|
||||
GalleryElementName string `json:"gallery_element_name"`
|
||||
Cancelled bool `json:"cancelled"` // Cancelled is true if the operation was cancelled
|
||||
Cancellable bool `json:"cancellable"` // Cancellable is true if the operation can be cancelled
|
||||
|
||||
// Nodes is the per-node breakdown for a fanned-out backend install.
|
||||
// Populated by DistributedBackendManager (per-node terminal status)
|
||||
// and by the Phase 2 progress bridge (per-byte ticks). The
|
||||
// /api/operations handler surfaces this so the UI can render an
|
||||
// expandable per-node view of an in-flight install.
|
||||
Nodes []NodeProgress `json:"nodes,omitempty"`
|
||||
}
|
||||
|
||||
// NodeProgress is a single node's contribution to a backend install
|
||||
// operation. Populated by DistributedBackendManager (per-node terminal
|
||||
// status) and by the Phase 2 progress bridge (per-byte ticks). Read by
|
||||
// the /api/operations handler so the UI can render an expandable
|
||||
// per-node breakdown.
|
||||
type NodeProgress struct {
|
||||
NodeID string `json:"node_id"`
|
||||
NodeName string `json:"node_name"`
|
||||
Status string `json:"status"` // "queued" | "running_on_worker" | "success" | "error" | "downloading"
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
Current string `json:"current,omitempty"`
|
||||
Total string `json:"total,omitempty"`
|
||||
Percentage float64 `json:"percentage"`
|
||||
Phase string `json:"phase,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type OpCache struct {
|
||||
@@ -115,3 +146,31 @@ func (m *OpCache) GetStatus() (map[string]string, map[string]string) {
|
||||
|
||||
return processingModelsData, taskTypes
|
||||
}
|
||||
|
||||
// NodeScopedKeyPrefix marks an opcache key as node-scoped. It is used by
// InstallBackendOnNodeEndpoint so per-node installs do not collide on the
// bare backend name. Full key format: "node:<nodeID>:<backend>".
// /api/operations reads the key to extract the nodeID for the UI.
const NodeScopedKeyPrefix = "node:"

// NodeScopedKey builds the opcache key for a backend operation scoped to a
// single node. Because the nodeID is embedded behind the prefix,
// ParseNodeScopedKey can recover it later and the operations endpoint does
// not need to store the nodeID separately.
func NodeScopedKey(nodeID, backend string) string {
	scopedNode := NodeScopedKeyPrefix + nodeID
	return strings.Join([]string{scopedNode, backend}, ":")
}
|
||||
|
||||
// ParseNodeScopedKey extracts (nodeID, backend) from a key built by NodeScopedKey.
|
||||
// Returns ok=false for keys that lack the prefix or are missing the nodeID or
|
||||
// backend segment. Backend names containing colons are preserved because we
|
||||
// split on the first colon after the prefix only.
|
||||
func ParseNodeScopedKey(key string) (nodeID, backend string, ok bool) {
|
||||
rest, hasPrefix := strings.CutPrefix(key, NodeScopedKeyPrefix)
|
||||
if !hasPrefix {
|
||||
return "", "", false
|
||||
}
|
||||
nodeID, backend, ok = strings.Cut(rest, ":")
|
||||
if !ok || nodeID == "" || backend == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return nodeID, backend, true
|
||||
}
|
||||
|
||||
@@ -135,6 +135,47 @@ func (g *GalleryService) UpdateStatus(s string, op *OpStatus) {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateNodeProgress merges a per-node progress tick into OpStatus.Nodes,
|
||||
// keyed by nodeID, and mirrors the latest values into the aggregate
|
||||
// Progress / FileName / DownloadedFileSize / TotalFileSize / Message
|
||||
// fields so the legacy single-bar OperationsBar view keeps working
|
||||
// unchanged alongside the new per-node breakdown.
|
||||
//
|
||||
// We deliberately do NOT delegate the aggregate mirror to UpdateStatus
|
||||
// here: UpdateStatus overwrites the entire OpStatus, which would clobber
|
||||
// the Nodes slice we just merged into. Doing the merge + mirror under a
|
||||
// single lock keeps both views consistent and concurrent-safe.
|
||||
func (g *GalleryService) UpdateNodeProgress(opID, nodeID string, np NodeProgress) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
status := g.statuses[opID]
|
||||
if status == nil {
|
||||
status = &OpStatus{}
|
||||
g.statuses[opID] = status
|
||||
}
|
||||
merged := false
|
||||
for i := range status.Nodes {
|
||||
if status.Nodes[i].NodeID == nodeID {
|
||||
status.Nodes[i] = np
|
||||
merged = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !merged {
|
||||
status.Nodes = append(status.Nodes, np)
|
||||
}
|
||||
|
||||
// Mirror the latest tick into the legacy aggregate fields so the
|
||||
// existing single-bar UI keeps rendering meaningful progress.
|
||||
status.FileName = np.FileName
|
||||
status.Progress = np.Percentage
|
||||
status.DownloadedFileSize = np.Current
|
||||
status.TotalFileSize = np.Total
|
||||
if np.Phase != "" {
|
||||
status.Message = np.Phase
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GalleryService) GetStatus(s string) *OpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
23
core/services/messaging/backend_install_progress.go
Normal file
23
core/services/messaging/backend_install_progress.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package messaging
|
||||
|
||||
// BackendInstallProgressEvent is the wire payload published by a worker to
|
||||
// nodes.<nodeID>.backend.install.<opID>.progress while a long-running install
|
||||
// is in flight. Transient: dropped events are acceptable, the master relies
|
||||
// on BackendInstallReply for ground truth on success/failure.
|
||||
type BackendInstallProgressEvent struct {
|
||||
OpID string `json:"op_id"`
|
||||
NodeID string `json:"node_id"`
|
||||
Backend string `json:"backend"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
Current string `json:"current,omitempty"` // human-readable size, e.g. "412 MB"
|
||||
Total string `json:"total,omitempty"` // human-readable size, e.g. "2.1 GB"
|
||||
Percentage float64 `json:"percentage"`
|
||||
Phase string `json:"phase,omitempty"` // "resolving" | "downloading" | "extracting" | "starting"
|
||||
}
|
||||
|
||||
// SubjectNodeBackendInstallProgress returns the NATS subject for transient
|
||||
// progress events emitted by a worker during a single backend.install run.
|
||||
// Per-op so multiple concurrent installs on the same node never alias.
|
||||
func SubjectNodeBackendInstallProgress(nodeID, opID string) string {
|
||||
return subjectNodePrefix + sanitizeSubjectToken(nodeID) + ".backend.install." + sanitizeSubjectToken(opID) + ".progress"
|
||||
}
|
||||
51
core/services/messaging/backend_install_progress_test.go
Normal file
51
core/services/messaging/backend_install_progress_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
var _ = Describe("BackendInstallProgress", func() {
|
||||
Context("SubjectNodeBackendInstallProgress", func() {
|
||||
It("composes the per-op progress subject", func() {
|
||||
Expect(messaging.SubjectNodeBackendInstallProgress("node-abc", "op-123")).
|
||||
To(Equal("nodes.node-abc.backend.install.op-123.progress"))
|
||||
})
|
||||
|
||||
It("sanitizes NATS-reserved characters in node and op tokens", func() {
|
||||
// '.' is the NATS hierarchy delimiter, '*' and '>' are wildcards,
|
||||
// and whitespace must be stripped - sanitizeSubjectToken replaces
|
||||
// all of them with '-'. The resulting subject must still parse as
|
||||
// exactly six hierarchy segments: nodes/<node>/backend/install/<op>/progress.
|
||||
subj := messaging.SubjectNodeBackendInstallProgress("a.b c", "x.y z")
|
||||
Expect(subj).ToNot(ContainSubstring(" "))
|
||||
Expect(strings.Count(subj, ".")).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BackendInstallProgressEvent", func() {
|
||||
It("JSON round-trips with all known fields", func() {
|
||||
ev := messaging.BackendInstallProgressEvent{
|
||||
OpID: "op-123",
|
||||
NodeID: "node-abc",
|
||||
Backend: "vllm",
|
||||
FileName: "vllm-cpu.tar.zst",
|
||||
Current: "412 MB",
|
||||
Total: "2.1 GB",
|
||||
Percentage: 19.6,
|
||||
Phase: "downloading",
|
||||
}
|
||||
raw, err := json.Marshal(ev)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var got messaging.BackendInstallProgressEvent
|
||||
Expect(json.Unmarshal(raw, &got)).To(Succeed())
|
||||
Expect(got).To(Equal(ev))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -144,6 +144,12 @@ type BackendInstallRequest struct {
|
||||
// worker still works (the master's install fallback path also uses this
|
||||
// when backend.upgrade returns nats.ErrNoResponders).
|
||||
Force bool `json:"force,omitempty"`
|
||||
// OpID identifies the admin-side operation. When non-empty the worker
|
||||
// publishes BackendInstallProgressEvent values to
|
||||
// SubjectNodeBackendInstallProgress(nodeID, OpID) while the install is
|
||||
// running, debounced to roughly 250ms. Empty means the caller is a
|
||||
// reconciler-driven retry that does not need progress streamed.
|
||||
OpID string `json:"op_id,omitempty"`
|
||||
}
|
||||
|
||||
// BackendInstallReply is the response from a backend.install NATS request.
|
||||
|
||||
120
core/services/nodes/install_progress_publisher.go
Normal file
120
core/services/nodes/install_progress_publisher.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
// DebouncedInstallProgressPublisher buffers backend-install download ticks
|
||||
// and publishes them to the per-op NATS progress subject at most once per
|
||||
// `interval`. Always publishes the final event on Flush so the UI sees the
|
||||
// terminal percentage.
|
||||
//
|
||||
// Behavior: leading-edge debounce. The first OnDownload after a quiet window
|
||||
// publishes immediately; subsequent ticks within `interval` only buffer the
|
||||
// latest event, which is then emitted via a single trailing timer. This
|
||||
// keeps the wire chatter bounded (~4 events per second at 250ms) while
|
||||
// still surfacing every meaningful percentage jump.
|
||||
//
|
||||
// Lock ordering: never hold p.mu across a Publish call. Publish hits the
|
||||
// NATS client which may block on a slow link, and we don't want a stalled
|
||||
// network to stall the underlying gallery download loop.
|
||||
type DebouncedInstallProgressPublisher struct {
|
||||
mu sync.Mutex
|
||||
client messaging.MessagingClient
|
||||
subject string
|
||||
nodeID string
|
||||
opID string
|
||||
backend string
|
||||
interval time.Duration
|
||||
lastPublishedAt time.Time
|
||||
pending *messaging.BackendInstallProgressEvent
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewDebouncedInstallProgressPublisher constructs a publisher for one
|
||||
// install operation. interval is the leading-edge debounce window
|
||||
// (~250ms in production).
|
||||
func NewDebouncedInstallProgressPublisher(client messaging.MessagingClient, nodeID, opID, backend string, interval time.Duration) *DebouncedInstallProgressPublisher {
|
||||
return &DebouncedInstallProgressPublisher{
|
||||
client: client,
|
||||
subject: messaging.SubjectNodeBackendInstallProgress(nodeID, opID),
|
||||
nodeID: nodeID,
|
||||
opID: opID,
|
||||
backend: backend,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// OnDownload is the callback shape gallery.InstallBackendFromGallery and
|
||||
// galleryop.InstallExternalBackend pass into the worker. Each invocation
|
||||
// represents a single tick from the underlying io.Reader copy loop.
|
||||
func (p *DebouncedInstallProgressPublisher) OnDownload(file, current, total string, percentage float64) {
|
||||
ev := messaging.BackendInstallProgressEvent{
|
||||
OpID: p.opID,
|
||||
NodeID: p.nodeID,
|
||||
Backend: p.backend,
|
||||
FileName: file,
|
||||
Current: current,
|
||||
Total: total,
|
||||
Percentage: percentage,
|
||||
Phase: "downloading",
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
now := time.Now()
|
||||
if p.lastPublishedAt.IsZero() || now.Sub(p.lastPublishedAt) >= p.interval {
|
||||
// Leading edge: publish immediately.
|
||||
p.lastPublishedAt = now
|
||||
p.pending = nil
|
||||
p.mu.Unlock()
|
||||
_ = p.client.Publish(p.subject, ev)
|
||||
return
|
||||
}
|
||||
// Within the window: buffer the latest event and arm a trailing
|
||||
// publish. If a timer is already armed, we just overwrite p.pending so
|
||||
// the trailing publish carries the freshest data.
|
||||
p.pending = &ev
|
||||
if p.timer == nil {
|
||||
delay := p.interval - now.Sub(p.lastPublishedAt)
|
||||
p.timer = time.AfterFunc(delay, p.flushPending)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// flushPending is the trailing-edge publisher fired by the AfterFunc timer.
|
||||
// It clears the pending slot under the lock, then publishes outside the
|
||||
// lock so Publish never blocks an in-progress OnDownload call.
|
||||
func (p *DebouncedInstallProgressPublisher) flushPending() {
|
||||
p.mu.Lock()
|
||||
p.timer = nil
|
||||
pending := p.pending
|
||||
p.pending = nil
|
||||
if pending != nil {
|
||||
p.lastPublishedAt = time.Now()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if pending != nil {
|
||||
_ = p.client.Publish(p.subject, *pending)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush publishes any pending buffered event synchronously and stops the
|
||||
// pending timer. Safe to call multiple times. Callers MUST defer Flush
|
||||
// after constructing the publisher so the terminal percentage reaches the
|
||||
// master even on error returns.
|
||||
func (p *DebouncedInstallProgressPublisher) Flush() {
|
||||
p.mu.Lock()
|
||||
if p.timer != nil {
|
||||
p.timer.Stop()
|
||||
p.timer = nil
|
||||
}
|
||||
pending := p.pending
|
||||
p.pending = nil
|
||||
p.mu.Unlock()
|
||||
if pending != nil {
|
||||
_ = p.client.Publish(p.subject, *pending)
|
||||
}
|
||||
}
|
||||
48
core/services/nodes/install_progress_publisher_test.go
Normal file
48
core/services/nodes/install_progress_publisher_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
var _ = Describe("DebouncedInstallProgressPublisher", func() {
|
||||
It("publishes the first event immediately and debounces subsequent ones within the window", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
pub := NewDebouncedInstallProgressPublisher(mc, "n1", "op1", "vllm", 50*time.Millisecond)
|
||||
|
||||
// Three rapid-fire ticks within the debounce window.
|
||||
pub.OnDownload("vllm.tar.zst", "100 MB", "1 GB", 10.0)
|
||||
pub.OnDownload("vllm.tar.zst", "200 MB", "1 GB", 20.0)
|
||||
pub.OnDownload("vllm.tar.zst", "300 MB", "1 GB", 30.0)
|
||||
pub.Flush()
|
||||
|
||||
// First event publishes immediately; the others coalesce; Flush guarantees a final.
|
||||
// So we expect at least 2 publishes and at most 4 (lead + final + any window-bounded).
|
||||
Eventually(func() int {
|
||||
return len(mc.publishCalls(messaging.SubjectNodeBackendInstallProgress("n1", "op1")))
|
||||
}, "1s").Should(BeNumerically(">=", 2))
|
||||
calls := mc.publishCalls(messaging.SubjectNodeBackendInstallProgress("n1", "op1"))
|
||||
Expect(len(calls)).To(BeNumerically("<=", 4),
|
||||
"three ticks within the debounce window should produce at most ~4 publishes")
|
||||
})
|
||||
|
||||
It("publishes the final event after Flush with the latest percentage", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
pub := NewDebouncedInstallProgressPublisher(mc, "n1", "op1", "vllm", 50*time.Millisecond)
|
||||
|
||||
pub.OnDownload("vllm.tar.zst", "1 GB", "1 GB", 100.0)
|
||||
pub.Flush()
|
||||
|
||||
Eventually(func() float64 {
|
||||
calls := mc.publishCalls(messaging.SubjectNodeBackendInstallProgress("n1", "op1"))
|
||||
if len(calls) == 0 {
|
||||
return -1
|
||||
}
|
||||
return calls[len(calls)-1].Percentage
|
||||
}, "1s").Should(Equal(100.0))
|
||||
})
|
||||
})
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -48,6 +49,13 @@ func (d *DistributedModelManager) InstallModel(ctx context.Context, op *galleryo
|
||||
return d.local.InstallModel(ctx, op, progressCb)
|
||||
}
|
||||
|
||||
// nodeProgressSink is the narrow interface DistributedBackendManager uses to
|
||||
// publish per-node progress without dragging in the full *GalleryService.
|
||||
// nil means "no sink, skip per-node writes" (used by single-node tests).
|
||||
type nodeProgressSink interface {
|
||||
UpdateNodeProgress(opID, nodeID string, np galleryop.NodeProgress)
|
||||
}
|
||||
|
||||
// DistributedBackendManager wraps a local BackendManager and adds NATS fan-out
|
||||
// for backend deletion so worker nodes clean up stale files.
|
||||
type DistributedBackendManager struct {
|
||||
@@ -56,16 +64,20 @@ type DistributedBackendManager struct {
|
||||
registry *NodeRegistry
|
||||
backendGalleries []config.Gallery
|
||||
systemState *system.SystemState
|
||||
progressSink nodeProgressSink
|
||||
}
|
||||
|
||||
// NewDistributedBackendManager creates a DistributedBackendManager.
|
||||
func NewDistributedBackendManager(appConfig *config.ApplicationConfig, ml *model.ModelLoader, adapter *RemoteUnloaderAdapter, registry *NodeRegistry) *DistributedBackendManager {
|
||||
// progressSink may be nil to disable per-node OpStatus writes (single-node
|
||||
// tests don't need it).
|
||||
func NewDistributedBackendManager(appConfig *config.ApplicationConfig, ml *model.ModelLoader, adapter *RemoteUnloaderAdapter, registry *NodeRegistry, progressSink nodeProgressSink) *DistributedBackendManager {
|
||||
return &DistributedBackendManager{
|
||||
local: galleryop.NewLocalBackendManager(appConfig, ml),
|
||||
adapter: adapter,
|
||||
registry: registry,
|
||||
backendGalleries: appConfig.BackendGalleries,
|
||||
systemState: appConfig.SystemState,
|
||||
progressSink: progressSink,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,7 +87,7 @@ func NewDistributedBackendManager(appConfig *config.ApplicationConfig, ml *model
|
||||
type NodeOpStatus struct {
|
||||
NodeID string `json:"node_id"`
|
||||
NodeName string `json:"node_name"`
|
||||
Status string `json:"status"` // "success" | "queued" | "error"
|
||||
Status string `json:"status"` // "success" | "queued" | "error" | "running_on_worker"
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
@@ -116,25 +128,48 @@ func (r BackendOpResult) Err() error {
|
||||
// when the node returns.
|
||||
// targetNodeIDs is an optional allowlist: when non-nil, only nodes whose ID is
|
||||
// in the set are visited. Used by UpgradeBackend to avoid asking nodes that
|
||||
// never had the backend installed to "upgrade" it — such requests fail at the
|
||||
// never had the backend installed to "upgrade" it - such requests fail at the
|
||||
// gallery (no platform variant) and would otherwise leave a forever-retrying
|
||||
// pending_backend_ops row. nil means "fan out to every node" (Install/Delete).
|
||||
func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context, op, backend string, galleriesJSON []byte, targetNodeIDs map[string]bool, apply func(node BackendNode) error) (BackendOpResult, error) {
|
||||
//
|
||||
// opID is the gallery operation identifier; when non-empty and progressSink is
|
||||
// set, every per-node terminal status appended to BackendOpResult is also
|
||||
// mirrored into the sink so the UI's per-node OpStatus.Nodes view stays in
|
||||
// lockstep with the manager's view. opID may be empty for ops that aren't
|
||||
// gallery-tracked (e.g. DeleteBackend's plain code path).
|
||||
func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context, opID, op, backend string, galleriesJSON []byte, targetNodeIDs map[string]bool, apply func(node BackendNode) error) (BackendOpResult, error) {
|
||||
allNodes, err := d.registry.List(ctx)
|
||||
if err != nil {
|
||||
return BackendOpResult{}, err
|
||||
}
|
||||
|
||||
// emitNodeProgress is a small helper that funnels every NodeOpStatus we
|
||||
// append to result.Nodes into the per-node OpStatus sink (when configured
|
||||
// and opID is known). Keeping it inline avoids drift between the
|
||||
// BackendOpResult view and the sink view - they're written from the same
|
||||
// code path on the same terminal statuses.
|
||||
emitNodeProgress := func(node BackendNode, status, errMsg string) {
|
||||
if d.progressSink == nil || opID == "" {
|
||||
return
|
||||
}
|
||||
d.progressSink.UpdateNodeProgress(opID, node.ID, galleryop.NodeProgress{
|
||||
NodeID: node.ID,
|
||||
NodeName: node.Name,
|
||||
Status: status,
|
||||
Error: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
result := BackendOpResult{Nodes: make([]NodeOpStatus, 0, len(allNodes))}
|
||||
for _, node := range allNodes {
|
||||
// Pending nodes haven't been approved yet — no intent to apply.
|
||||
// Pending nodes haven't been approved yet - no intent to apply.
|
||||
if node.Status == StatusPending {
|
||||
continue
|
||||
}
|
||||
// Backend lifecycle ops only make sense on backend-type workers.
|
||||
// Agent workers don't subscribe to backend.install/delete/list, so
|
||||
// enqueueing for them guarantees a forever-retrying row that the
|
||||
// reconciler can never drain. Silently skip — they aren't consumers.
|
||||
// reconciler can never drain. Silently skip - they aren't consumers.
|
||||
if node.NodeType != "" && node.NodeType != NodeTypeBackend {
|
||||
continue
|
||||
}
|
||||
@@ -143,19 +178,23 @@ func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context
|
||||
}
|
||||
if err := d.registry.UpsertPendingBackendOp(ctx, node.ID, backend, op, galleriesJSON); err != nil {
|
||||
xlog.Warn("Failed to enqueue backend op", "op", op, "node", node.Name, "backend", backend, "error", err)
|
||||
errMsg := fmt.Sprintf("enqueue failed: %v", err)
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "error",
|
||||
Error: fmt.Sprintf("enqueue failed: %v", err),
|
||||
Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, "error", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
if node.Status != StatusHealthy {
|
||||
// Intent is recorded; reconciler will retry when the node recovers.
|
||||
errMsg := fmt.Sprintf("node %s, will retry when healthy", node.Status)
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "queued",
|
||||
Error: fmt.Sprintf("node %s, will retry when healthy", node.Status),
|
||||
Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, "queued", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -169,12 +208,31 @@ func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "success",
|
||||
})
|
||||
emitNodeProgress(node, "success", "")
|
||||
continue
|
||||
}
|
||||
|
||||
// Record failure for backoff. If it's an ErrNoResponders, the node's
|
||||
// gone AWOL — mark unhealthy so the router stops picking it too.
|
||||
// gone AWOL - mark unhealthy so the router stops picking it too.
|
||||
errMsg := applyErr.Error()
|
||||
|
||||
// Worker-still-installing is a "soft" failure: the worker is most
|
||||
// likely still pulling the OCI image. Keep the row, push NextRetryAt
|
||||
// out so the reconciler does not immediately re-fire another install
|
||||
// while the worker is still busy, and report the in-progress state
|
||||
// to the caller. The next reconciler pass / backend.list confirms
|
||||
// the actual outcome.
|
||||
if errors.Is(applyErr, galleryop.ErrWorkerStillInstalling) {
|
||||
if id, err := d.findPendingRow(ctx, node.ID, backend, op); err == nil {
|
||||
_ = d.registry.RecordPendingBackendOpInFlight(ctx, id, errMsg, d.adapter.InstallTimeout())
|
||||
}
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "running_on_worker", Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, "running_on_worker", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
if errors.Is(applyErr, nats.ErrNoResponders) {
|
||||
xlog.Warn("No NATS responders for node, marking unhealthy", "node", node.Name, "nodeID", node.ID)
|
||||
d.registry.MarkUnhealthy(ctx, node.ID)
|
||||
@@ -185,6 +243,7 @@ func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "error", Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, "error", errMsg)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -226,7 +285,11 @@ func (d *DistributedBackendManager) DeleteBackend(name string) error {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
// Empty opID: plain DeleteBackend isn't gallery-tracked the same way as
|
||||
// Install/Upgrade (no progress dialog), so we skip the per-node sink
|
||||
// writes here. DeleteBackendDetailed is the HTTP path that surfaces
|
||||
// per-node results in its own response.
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, "", OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
reply, err := d.adapter.DeleteBackend(node.ID, name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -249,7 +312,7 @@ func (d *DistributedBackendManager) DeleteBackendDetailed(ctx context.Context, n
|
||||
if err := d.local.DeleteBackend(name); err != nil && !errors.Is(err, gallery.ErrBackendNotFound) {
|
||||
return BackendOpResult{}, err
|
||||
}
|
||||
return d.enqueueAndDrainBackendOp(ctx, OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
return d.enqueueAndDrainBackendOp(ctx, "", OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
reply, err := d.adapter.DeleteBackend(node.ID, name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -324,22 +387,113 @@ func (d *DistributedBackendManager) ListBackends() (gallery.SystemBackends, erro
|
||||
result[b.Name] = entry
|
||||
}
|
||||
}
|
||||
|
||||
// Proactively clear pending_backend_ops install rows whose intent is now
|
||||
// satisfied: the backend is reported installed on its target node. Without
|
||||
// this, the row sits in the queue until next_retry_at expires (up to the
|
||||
// install timeout, default 15m) and the operator UI shows the install as
|
||||
// "still installing in background" for that whole window even though the
|
||||
// worker has actually been ready for minutes. We only clear install rows;
|
||||
// upgrade and delete rows have presence-based semantics that do NOT match
|
||||
// backend.list confirmation.
|
||||
d.clearSatisfiedInstallRows(context.Background(), result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// clearSatisfiedInstallRows removes pending_backend_ops install rows whose
|
||||
// (nodeID, backend) pair now appears in the cluster-wide backend listing.
|
||||
// Called by ListBackends after fan-out so the proactive clear sees every
|
||||
// node's report. Best-effort: a DB failure is logged and the row stays for
|
||||
// the reconciler to drain via its slower path.
|
||||
func (d *DistributedBackendManager) clearSatisfiedInstallRows(ctx context.Context, backends gallery.SystemBackends) {
|
||||
rows, err := d.registry.ListPendingBackendOps(ctx)
|
||||
if err != nil {
|
||||
xlog.Debug("clearSatisfiedInstallRows: failed to list pending ops", "error", err)
|
||||
return
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return
|
||||
}
|
||||
// Build a (nodeID, backend) presence set from the listing.
|
||||
present := make(map[string]map[string]bool, len(backends))
|
||||
for name, b := range backends {
|
||||
for _, ref := range b.Nodes {
|
||||
if present[ref.NodeID] == nil {
|
||||
present[ref.NodeID] = make(map[string]bool)
|
||||
}
|
||||
present[ref.NodeID][name] = true
|
||||
}
|
||||
}
|
||||
for _, row := range rows {
|
||||
if row.Op != OpBackendInstall {
|
||||
continue
|
||||
}
|
||||
if !present[row.NodeID][row.Backend] {
|
||||
continue
|
||||
}
|
||||
if err := d.registry.DeletePendingBackendOp(ctx, row.ID); err != nil {
|
||||
xlog.Debug("clearSatisfiedInstallRows: delete failed",
|
||||
"id", row.ID, "node", row.NodeID, "backend", row.Backend, "error", err)
|
||||
continue
|
||||
}
|
||||
xlog.Info("Reconciler: pending install row satisfied by backend.list",
|
||||
"node", row.NodeID, "backend", row.Backend)
|
||||
}
|
||||
}
|
||||
|
||||
// InstallBackend fans out installation through the pending-ops queue so
|
||||
// non-healthy nodes get retried when they come back instead of being silently
|
||||
// skipped. Reply success from the NATS round-trip deletes the queue row;
|
||||
// reply.Success==false is treated as an error so the row stays for retry.
|
||||
//
|
||||
// When op.TargetNodeID is set, only that node is visited - the same allowlist
|
||||
// path UpgradeBackend uses. Empty TargetNodeID preserves the original fan-out
|
||||
// behavior so the periodic reconciler and /api/backends/install/:id keep
|
||||
// working unchanged.
|
||||
func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *galleryop.ManagementOp[gallery.GalleryBackend, any], progressCb galleryop.ProgressCallback) error {
|
||||
galleriesJSON, _ := json.Marshal(op.Galleries)
|
||||
backendName := op.GalleryElementName
|
||||
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendInstall, backendName, galleriesJSON, nil, func(node BackendNode) error {
|
||||
var targetNodeIDs map[string]bool
|
||||
if op.TargetNodeID != "" {
|
||||
targetNodeIDs = map[string]bool{op.TargetNodeID: true}
|
||||
}
|
||||
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, op.ID, OpBackendInstall, backendName, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
// onProgress fans each BackendInstallProgressEvent into two
|
||||
// observers: the legacy single-bar progressCb (kept so callers
|
||||
// that only consume the aggregate view keep working) and the
|
||||
// per-node sink (so OpStatus.Nodes gets a "downloading" tick
|
||||
// per file/percentage with node attribution). Defined inside the
|
||||
// loop so each node captures its own node.Name into the closure.
|
||||
onProgress := func(ev messaging.BackendInstallProgressEvent) {
|
||||
if progressCb != nil {
|
||||
progressCb(ev.FileName, ev.Current, ev.Total, ev.Percentage)
|
||||
}
|
||||
if d.progressSink != nil && op.ID != "" {
|
||||
d.progressSink.UpdateNodeProgress(op.ID, ev.NodeID, galleryop.NodeProgress{
|
||||
NodeID: ev.NodeID,
|
||||
NodeName: node.Name,
|
||||
Status: "downloading",
|
||||
FileName: ev.FileName,
|
||||
Current: ev.Current,
|
||||
Total: ev.Total,
|
||||
Percentage: ev.Percentage,
|
||||
Phase: ev.Phase,
|
||||
})
|
||||
}
|
||||
}
|
||||
// nil-callback shortcut: when there is nothing to deliver to,
|
||||
// hand the adapter a nil onProgress so it skips the per-op NATS
|
||||
// subscription. Matches the pre-Phase-4 bridgeProgressCb semantics.
|
||||
var onProgressArg func(messaging.BackendInstallProgressEvent)
|
||||
if progressCb != nil || d.progressSink != nil {
|
||||
onProgressArg = onProgress
|
||||
}
|
||||
// Admin-driven backend install: not tied to a specific replica slot.
|
||||
// Pass replica 0 — the worker's processKey is "backend#0" when no
|
||||
// Pass replica 0 - the worker's processKey is "backend#0" when no
|
||||
// modelID is supplied, matching pre-PR4 behavior.
|
||||
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0)
|
||||
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0, op.ID, onProgressArg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -351,7 +505,19 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return result.Err()
|
||||
if hardErr := result.Err(); hardErr != nil {
|
||||
return hardErr
|
||||
}
|
||||
// No hard failures, but if at least one node reported running_on_worker,
|
||||
// surface a wrapped ErrWorkerStillInstalling so galleryop can render a
|
||||
// yellow in-progress state instead of green success. The reconciler
|
||||
// will confirm the actual outcome on its next pass via backend.list.
|
||||
for _, n := range result.Nodes {
|
||||
if n.Status == "running_on_worker" {
|
||||
return fmt.Errorf("%w: %s", galleryop.ErrWorkerStillInstalling, summarizeRunningOnWorker(result.Nodes))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpgradeBackend uses a separate NATS subject (backend.upgrade) so the slow
|
||||
@@ -382,7 +548,11 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
|
||||
targetNodeIDs[n.NodeID] = true
|
||||
}
|
||||
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
// Empty opID: the caller (galleryop) doesn't thread an op ID into
|
||||
// UpgradeBackend today, so we can't tag per-node sink writes with the
|
||||
// right OpStatus key. Until the upgrade path takes a ManagementOp the
|
||||
// way InstallBackend does, the sink stays no-op here.
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, "", OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0)
|
||||
if err != nil {
|
||||
// Rolling-update fallback: an older worker doesn't know
|
||||
@@ -407,7 +577,18 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return result.Err()
|
||||
if hardErr := result.Err(); hardErr != nil {
|
||||
return hardErr
|
||||
}
|
||||
// Same in-progress surfacing as InstallBackend: a long-running worker
|
||||
// upgrade that timed out the NATS round-trip must not be reported as
|
||||
// green success.
|
||||
for _, n := range result.Nodes {
|
||||
if n.Status == "running_on_worker" {
|
||||
return fmt.Errorf("%w: %s", galleryop.ErrWorkerStillInstalling, summarizeRunningOnWorker(result.Nodes))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDistributed reports that installs from this manager fan out across the
|
||||
@@ -433,3 +614,16 @@ func (d *DistributedBackendManager) CheckUpgrades(ctx context.Context) (map[stri
|
||||
// it used to come from the empty frontend filesystem.
|
||||
return gallery.CheckUpgradesAgainst(ctx, d.backendGalleries, d.systemState, installed)
|
||||
}
|
||||
|
||||
// summarizeRunningOnWorker builds a short human-readable summary of which
|
||||
// nodes are still installing in the background, for inclusion in the
|
||||
// wrapped ErrWorkerStillInstalling error.
|
||||
func summarizeRunningOnWorker(nodes []NodeOpStatus) string {
|
||||
var names []string
|
||||
for _, n := range nodes {
|
||||
if n.Status == "running_on_worker" {
|
||||
names = append(names, n.NodeName)
|
||||
}
|
||||
}
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
@@ -22,11 +24,35 @@ import (
|
||||
// (or error). Used so each fan-out request can simulate a different worker
|
||||
// outcome without spinning up real NATS.
|
||||
type scriptedMessagingClient struct {
|
||||
mu sync.Mutex
|
||||
replies map[string][]byte
|
||||
errs map[string]error
|
||||
calls []requestCall
|
||||
matchedReplies map[string][]matchedReply
|
||||
mu sync.Mutex
|
||||
replies map[string][]byte
|
||||
errs map[string]error
|
||||
calls []requestCall
|
||||
matchedReplies map[string][]matchedReply
|
||||
publishes []progressPublishCall
|
||||
scheduledProgressPublishes []scheduledProgressPublish
|
||||
subscribes []string
|
||||
}
|
||||
|
||||
// progressPublishCall records a single Publish invocation. The progress
|
||||
// publisher tests assert on the sequence of BackendInstallProgressEvent
|
||||
// values written to a per-op subject, so we capture both subject and the
|
||||
// decoded event. Named to avoid clashing with the simpler `publishCall`
|
||||
// already defined in unloader_test.go (which stores raw JSON bytes for
|
||||
// non-progress assertions).
|
||||
type progressPublishCall struct {
|
||||
Subject string
|
||||
Event messaging.BackendInstallProgressEvent
|
||||
}
|
||||
|
||||
// scheduledProgressPublish queues a batch of BackendInstallProgressEvent
|
||||
// values to be delivered the next time Subscribe is called with the matching
|
||||
// subject. This lets master-side tests assert that the adapter installs its
|
||||
// handler BEFORE publishing the install request, by scripting events to be
|
||||
// delivered as soon as the subscription appears.
|
||||
type scheduledProgressPublish struct {
|
||||
subject string
|
||||
events []messaging.BackendInstallProgressEvent
|
||||
}
|
||||
|
||||
// matchedReply lets a test script a canned reply that only fires when the
|
||||
@@ -98,10 +124,10 @@ func (s *scriptedMessagingClient) scriptReplyMatching(subject string, pred func(
|
||||
})
|
||||
}
|
||||
|
||||
func (s *scriptedMessagingClient) Request(subject string, data []byte, _ time.Duration) ([]byte, error) {
|
||||
func (s *scriptedMessagingClient) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.calls = append(s.calls, requestCall{Subject: subject, Data: data})
|
||||
s.calls = append(s.calls, requestCall{Subject: subject, Data: data, Timeout: timeout})
|
||||
|
||||
// Predicate-matched replies take precedence over flat scriptReply.
|
||||
if matchers, ok := s.matchedReplies[subject]; ok {
|
||||
@@ -135,8 +161,88 @@ func (s *scriptedMessagingClient) Request(subject string, data []byte, _ time.Du
|
||||
return nil, &fakeNoRespondersErr{}
|
||||
}
|
||||
|
||||
func (s *scriptedMessagingClient) Publish(_ string, _ any) error { return nil }
|
||||
func (s *scriptedMessagingClient) Subscribe(_ string, _ func([]byte)) (messaging.Subscription, error) {
|
||||
// Publish records each call so progress-publisher tests can assert on the
|
||||
// stream of events written to a subject. The real messaging.Client JSON
|
||||
// encodes the payload before sending, but our publisher hands a typed
|
||||
// struct directly, so we handle both shapes.
|
||||
func (s *scriptedMessagingClient) Publish(subject string, data any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
switch ev := data.(type) {
|
||||
case messaging.BackendInstallProgressEvent:
|
||||
s.publishes = append(s.publishes, progressPublishCall{Subject: subject, Event: ev})
|
||||
case []byte:
|
||||
var e messaging.BackendInstallProgressEvent
|
||||
_ = json.Unmarshal(ev, &e)
|
||||
s.publishes = append(s.publishes, progressPublishCall{Subject: subject, Event: e})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishCalls returns every BackendInstallProgressEvent that was published
|
||||
// to `subject`, in order. Lets tests assert on debounce behavior without
|
||||
// depending on internal Publish timing.
|
||||
func (s *scriptedMessagingClient) publishCalls(subject string) []messaging.BackendInstallProgressEvent {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]messaging.BackendInstallProgressEvent, 0)
|
||||
for _, c := range s.publishes {
|
||||
if c.Subject != subject {
|
||||
continue
|
||||
}
|
||||
out = append(out, c.Event)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// scheduleProgressPublish queues a set of BackendInstallProgressEvent values
|
||||
// to be delivered on the next Subscribe call matching the per-op progress
|
||||
// subject. A short delay before delivery gives the subscriber time to install
|
||||
// its message handler before the events arrive.
|
||||
func (s *scriptedMessagingClient) scheduleProgressPublish(nodeID, opID string, events []messaging.BackendInstallProgressEvent) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.scheduledProgressPublishes = append(s.scheduledProgressPublishes, scheduledProgressPublish{
|
||||
subject: messaging.SubjectNodeBackendInstallProgress(nodeID, opID),
|
||||
events: events,
|
||||
})
|
||||
}
|
||||
|
||||
// subscribeCalls returns the subjects on which Subscribe was invoked.
|
||||
// Used to confirm the master skipped subscription when onProgress was nil.
|
||||
func (s *scriptedMessagingClient) subscribeCalls() []string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]string, len(s.subscribes))
|
||||
copy(out, s.subscribes)
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *scriptedMessagingClient) Subscribe(subject string, handler func([]byte)) (messaging.Subscription, error) {
|
||||
s.mu.Lock()
|
||||
s.subscribes = append(s.subscribes, subject)
|
||||
matched := []scheduledProgressPublish{}
|
||||
remaining := s.scheduledProgressPublishes[:0]
|
||||
for _, sp := range s.scheduledProgressPublishes {
|
||||
if sp.subject == subject {
|
||||
matched = append(matched, sp)
|
||||
} else {
|
||||
remaining = append(remaining, sp)
|
||||
}
|
||||
}
|
||||
s.scheduledProgressPublishes = remaining
|
||||
s.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
for _, sp := range matched {
|
||||
for _, ev := range sp.events {
|
||||
raw, _ := json.Marshal(ev)
|
||||
handler(raw)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return &fakeSubscription{}, nil
|
||||
}
|
||||
func (s *scriptedMessagingClient) QueueSubscribe(_ string, _ string, _ func([]byte)) (messaging.Subscription, error) {
|
||||
@@ -151,8 +257,43 @@ func (s *scriptedMessagingClient) SubscribeReply(_ string, _ func([]byte, func([
|
||||
func (s *scriptedMessagingClient) IsConnected() bool { return true }
|
||||
func (s *scriptedMessagingClient) Close() {}
|
||||
|
||||
// recordingNodeCall captures a single UpdateNodeProgress invocation so
|
||||
// per-node OpStatus tests can assert on the sequence of writes the
|
||||
// DistributedBackendManager fans out into the sink.
|
||||
type recordingNodeCall struct {
|
||||
OpID string
|
||||
NodeID string
|
||||
Progress galleryop.NodeProgress
|
||||
}
|
||||
|
||||
// recordingProgressSink is a test-only nodeProgressSink that just records
|
||||
// every call. Used by the per-node OpStatus specs below to assert the
|
||||
// manager wrote the expected terminal and downloading entries.
|
||||
type recordingProgressSink struct {
|
||||
mu sync.Mutex
|
||||
calls []recordingNodeCall
|
||||
}
|
||||
|
||||
func (r *recordingProgressSink) UpdateNodeProgress(opID, nodeID string, np galleryop.NodeProgress) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.calls = append(r.calls, recordingNodeCall{OpID: opID, NodeID: nodeID, Progress: np})
|
||||
}
|
||||
|
||||
func (r *recordingProgressSink) callsFor(opID, nodeID string) []galleryop.NodeProgress {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := []galleryop.NodeProgress{}
|
||||
for _, c := range r.calls {
|
||||
if c.OpID == opID && c.NodeID == nodeID {
|
||||
out = append(out, c.Progress)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// fakeNoRespondersErr is the unscripted-subject default. It matches
|
||||
// nats.ErrNoResponders by string only — used when a test forgets to script
|
||||
// nats.ErrNoResponders by string only - used when a test forgets to script
|
||||
// a node so the failure is loud but doesn't tickle errors.Is(...) sentinel
|
||||
// paths the test wasn't deliberately exercising. Tests that DO want the
|
||||
// real sentinel (e.g. to drive the manager's NoResponders fallback) call
|
||||
@@ -204,7 +345,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
mc = newScriptedMessagingClient()
|
||||
adapter = NewRemoteUnloaderAdapter(nil, mc)
|
||||
adapter = NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
mgr = &DistributedBackendManager{
|
||||
local: stubLocalBackendManager{},
|
||||
adapter: adapter,
|
||||
@@ -311,6 +452,304 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
Expect(mgr.InstallBackend(ctx, op("vllm-development"), nil)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when op.TargetNodeID is set to a healthy node", func() {
|
||||
It("installs only on that node, leaving the others untouched", func() {
|
||||
target := registerHealthyBackend("worker-target", "10.0.0.1:50051")
|
||||
other := registerHealthyBackend("worker-other", "10.0.0.2:50051")
|
||||
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(target.ID),
|
||||
messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"})
|
||||
// No reply scripted for `other`: if InstallBackend fans out
|
||||
// to it, the fakeNoRespondersErr default would surface and
|
||||
// the test would fail.
|
||||
|
||||
targetedOp := &galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
GalleryElementName: "llama-cpp",
|
||||
TargetNodeID: target.ID,
|
||||
}
|
||||
Expect(mgr.InstallBackend(ctx, targetedOp, nil)).To(Succeed())
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
Expect(mc.calls).To(HaveLen(1))
|
||||
Expect(mc.calls[0].Subject).To(Equal(messaging.SubjectNodeBackendInstall(target.ID)))
|
||||
Expect(mc.calls[0].Subject).ToNot(Equal(messaging.SubjectNodeBackendInstall(other.ID)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when op.TargetNodeID is set to a node that does not exist", func() {
|
||||
It("returns nil without sending any NATS request", func() {
|
||||
registerHealthyBackend("worker-a", "10.0.0.1:50051")
|
||||
|
||||
ghostOp := &galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
GalleryElementName: "llama-cpp",
|
||||
TargetNodeID: "this-id-does-not-exist",
|
||||
}
|
||||
Expect(mgr.InstallBackend(ctx, ghostOp, nil)).To(Succeed())
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
Expect(mc.calls).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when InstallBackend times out on a worker", func() {
|
||||
It("returns galleryop.ErrWorkerStillInstalling and keeps the queue row with NextRetryAt pushed out", func() {
|
||||
n := registerHealthyBackend("slow", "10.0.0.1:50051")
|
||||
|
||||
// Script a NATS timeout on the install subject. The adapter
|
||||
// wraps this into galleryop.ErrWorkerStillInstalling, which
|
||||
// the manager should treat as a soft failure.
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(n.ID), nats.ErrTimeout)
|
||||
|
||||
err := mgr.InstallBackend(ctx, op("vllm"), nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue(),
|
||||
"expected wrapped ErrWorkerStillInstalling, got %v", err)
|
||||
|
||||
rows, err := registry.ListPendingBackendOps(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rows).To(HaveLen(1))
|
||||
Expect(rows[0].Backend).To(Equal("vllm"))
|
||||
// The adapter is configured with a 3m install timeout in this
|
||||
// suite (NewRemoteUnloaderAdapter above). NextRetryAt should
|
||||
// be ~now+3m; a > now+2m bound is safe-but-tight enough to
|
||||
// catch the buggy short default (30s exponential backoff).
|
||||
Expect(rows[0].NextRetryAt).To(BeTemporally(">", time.Now().Add(2*time.Minute)),
|
||||
"NextRetryAt should be pushed to ~now+installTimeout, not the short default")
|
||||
})
|
||||
})
|
||||
|
||||
Context("end-to-end: timeout then successful reconcile via backend.list", func() {
|
||||
It("surfaces the install in ListBackends after the worker finishes", func() {
|
||||
// Use the same node-registration helper the Task 5 test uses
|
||||
// so the test fixture is identical to the prior context.
|
||||
node := registerHealthyBackend("jetson", "10.0.0.2:50051")
|
||||
|
||||
// First install attempt: NATS times out. The adapter wraps
|
||||
// this as galleryop.ErrWorkerStillInstalling and the manager
|
||||
// keeps the pending_backend_ops row alive with NextRetryAt
|
||||
// pushed out (asserted in the previous context).
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(node.ID), nats.ErrTimeout)
|
||||
|
||||
err := mgr.InstallBackend(ctx, op("vllm"), nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue(),
|
||||
"expected wrapped ErrWorkerStillInstalling, got %v", err)
|
||||
|
||||
rows, listErr := registry.ListPendingBackendOps(ctx)
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
Expect(rows).To(HaveLen(1))
|
||||
|
||||
// The worker finished installing in the background. Script
|
||||
// backend.list on the same scriptedMessagingClient so the
|
||||
// manager's ListBackends fan-out reports the backend.
|
||||
mc.scriptReply(messaging.SubjectNodeBackendList(node.ID), messaging.BackendListReply{
|
||||
Backends: []messaging.NodeBackendInfo{{Name: "vllm"}},
|
||||
})
|
||||
|
||||
backends, listErr := mgr.ListBackends()
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
Expect(backends).To(HaveKey("vllm"))
|
||||
Expect(backends["vllm"].Nodes).To(HaveLen(1))
|
||||
Expect(backends["vllm"].Nodes[0].NodeID).To(Equal(node.ID))
|
||||
|
||||
// Phase 1b shipped: ListBackends proactively clears install rows
|
||||
// whose intent is now satisfied by backend.list confirmation. The
|
||||
// operator UI clears immediately instead of waiting for the next
|
||||
// reconciler tick after NextRetryAt.
|
||||
rowsAfter, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rowsAfter).To(BeEmpty(),
|
||||
"install row should clear once backend.list confirms presence on the target node")
|
||||
})
|
||||
})
|
||||
|
||||
Context("ListBackends clears confirmed install rows", func() {
|
||||
It("deletes the pending_backend_ops install row when the backend is reported installed on its target node", func() {
|
||||
node := registerHealthyBackend("worker-a", "10.0.0.5:50051")
|
||||
|
||||
// Pre-stage: simulate an admin install that timed out at the NATS
|
||||
// round-trip, leaving an install row in the queue.
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(node.ID), nats.ErrTimeout)
|
||||
err := mgr.InstallBackend(ctx, op("vllm"), nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue())
|
||||
|
||||
rows, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rows).To(HaveLen(1))
|
||||
|
||||
// Worker finishes installing in the background. backend.list now
|
||||
// confirms presence; ListBackends should proactively clear the row.
|
||||
mc.scriptReply(messaging.SubjectNodeBackendList(node.ID), messaging.BackendListReply{
|
||||
Backends: []messaging.NodeBackendInfo{{Name: "vllm"}},
|
||||
})
|
||||
|
||||
backends, listErr := mgr.ListBackends()
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
Expect(backends).To(HaveKey("vllm"))
|
||||
|
||||
rowsAfter, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rowsAfter).To(BeEmpty(),
|
||||
"ListBackends should clear install rows whose intent is now satisfied by backend.list")
|
||||
})
|
||||
|
||||
It("does NOT clear an upgrade row even if the backend is reported installed", func() {
|
||||
node := registerHealthyBackend("worker-b", "10.0.0.6:50051")
|
||||
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, node.ID, "vllm", OpBackendUpgrade, []byte("[]"))).To(Succeed())
|
||||
|
||||
mc.scriptReply(messaging.SubjectNodeBackendList(node.ID), messaging.BackendListReply{
|
||||
Backends: []messaging.NodeBackendInfo{{Name: "vllm"}},
|
||||
})
|
||||
|
||||
_, listErr := mgr.ListBackends()
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
|
||||
rowsAfter, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rowsAfter).To(HaveLen(1), "upgrade rows must not be cleared by backend.list presence")
|
||||
})
|
||||
})
|
||||
|
||||
Context("InstallBackend streams progress events to the caller's progressCb", func() {
|
||||
It("invokes progressCb once per worker-published progress event", func() {
|
||||
node := registerHealthyBackend("worker-prog", "10.0.0.7:50051")
|
||||
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID), messaging.BackendInstallReply{Success: true, Address: "10.0.0.7:50051"})
|
||||
mc.scheduleProgressPublish(node.ID, "op-prog-1", []messaging.BackendInstallProgressEvent{
|
||||
{OpID: "op-prog-1", NodeID: node.ID, Backend: "vllm", FileName: "vllm.tar", Current: "100 MB", Total: "1 GB", Percentage: 10},
|
||||
{OpID: "op-prog-1", NodeID: node.ID, Backend: "vllm", FileName: "vllm.tar", Current: "1 GB", Total: "1 GB", Percentage: 100},
|
||||
})
|
||||
|
||||
type tick struct {
|
||||
FileName, Current, Total string
|
||||
Percentage float64
|
||||
}
|
||||
var (
|
||||
pcCalls []tick
|
||||
mu sync.Mutex
|
||||
)
|
||||
progressCb := func(file, current, total string, pct float64) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
pcCalls = append(pcCalls, tick{file, current, total, pct})
|
||||
}
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-prog-1"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, progressCb)).To(Succeed())
|
||||
|
||||
Eventually(func() int {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(pcCalls)
|
||||
}, "1s").Should(Equal(2))
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// The adapter dispatches each progress event to its own goroutine
|
||||
// (see unloader.go: `go onProgress(ev)`) so two events emitted back
|
||||
// to back can land at the bridge in either order. Assert the set of
|
||||
// percentages observed contains both ticks, rather than depending
|
||||
// on goroutine scheduling for ordering.
|
||||
pcts := []float64{pcCalls[0].Percentage, pcCalls[1].Percentage}
|
||||
Expect(pcts).To(ConsistOf(10.0, 100.0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("InstallBackend tolerates silent (pre-Phase-2) workers", func() {
|
||||
It("completes successfully even when no progress events are ever published", func() {
|
||||
node := registerHealthyBackend("worker-silent", "10.0.0.8:50051")
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID), messaging.BackendInstallReply{Success: true, Address: "10.0.0.8:50051"})
|
||||
// NO scheduleProgressPublish call - silent worker.
|
||||
|
||||
var ticks int
|
||||
var mu sync.Mutex
|
||||
progressCb := func(file, current, total string, pct float64) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
ticks++
|
||||
}
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-silent-1"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, progressCb)).To(Succeed())
|
||||
|
||||
Consistently(func() int {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return ticks
|
||||
}, "200ms").Should(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("populates per-node OpStatus entries", func() {
|
||||
var sink *recordingProgressSink
|
||||
|
||||
BeforeEach(func() {
|
||||
// Reconstruct mgr with the recording sink so the new code
|
||||
// path (per-node OpStatus writes) is exercised. The default
|
||||
// mgr in the outer BeforeEach has progressSink=nil so the
|
||||
// pre-existing specs keep verifying the no-sink behavior.
|
||||
sink = &recordingProgressSink{}
|
||||
appCfg := &config.ApplicationConfig{}
|
||||
mgr = NewDistributedBackendManager(appCfg, nil, adapter, registry, sink)
|
||||
// stubLocalBackendManager mirrors the production behaviour
|
||||
// where the frontend node rarely has the backend installed
|
||||
// locally - the NATS fan-out is what these specs verify.
|
||||
mgr.local = stubLocalBackendManager{}
|
||||
})
|
||||
|
||||
It("emits a success entry for each healthy node visited", func() {
|
||||
node := registerHealthyBackend("worker-ok", "10.0.0.9:50051")
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID),
|
||||
messaging.BackendInstallReply{Success: true, Address: "10.0.0.9:50051"})
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-node-success"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, nil)).To(Succeed())
|
||||
|
||||
calls := sink.callsFor("op-node-success", node.ID)
|
||||
Expect(calls).ToNot(BeEmpty())
|
||||
Expect(calls[len(calls)-1].Status).To(Equal("success"))
|
||||
Expect(calls[len(calls)-1].NodeName).To(Equal("worker-ok"))
|
||||
})
|
||||
|
||||
It("emits a running_on_worker entry when NATS times out", func() {
|
||||
node := registerHealthyBackend("worker-slow", "10.0.0.10:50051")
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(node.ID), nats.ErrTimeout)
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-node-slow"
|
||||
// Soft failure: returns wrapped ErrWorkerStillInstalling.
|
||||
_ = mgr.InstallBackend(ctx, opVal, nil)
|
||||
|
||||
calls := sink.callsFor("op-node-slow", node.ID)
|
||||
Expect(calls).ToNot(BeEmpty())
|
||||
Expect(calls[len(calls)-1].Status).To(Equal("running_on_worker"))
|
||||
})
|
||||
|
||||
It("emits downloading entries from progress events", func() {
|
||||
node := registerHealthyBackend("worker-dl", "10.0.0.11:50051")
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID),
|
||||
messaging.BackendInstallReply{Success: true})
|
||||
mc.scheduleProgressPublish(node.ID, "op-node-dl", []messaging.BackendInstallProgressEvent{
|
||||
{OpID: "op-node-dl", NodeID: node.ID, Backend: "vllm", FileName: "vllm.tar", Current: "1 GB", Total: "1 GB", Percentage: 100, Phase: "downloading"},
|
||||
})
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-node-dl"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, nil)).To(Succeed())
|
||||
|
||||
Eventually(func() bool {
|
||||
for _, np := range sink.callsFor("op-node-dl", node.ID) {
|
||||
if np.Status == "downloading" && np.Percentage == 100.0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, "1s").Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UpgradeBackend", func() {
|
||||
|
||||
@@ -68,9 +68,9 @@ type ModelScheduler interface {
|
||||
|
||||
// ReplicaReconcilerOptions holds configuration for creating a ReplicaReconciler.
|
||||
type ReplicaReconcilerOptions struct {
|
||||
Registry *NodeRegistry
|
||||
Registry *NodeRegistry
|
||||
Scheduler ModelScheduler
|
||||
Unloader NodeCommandSender
|
||||
Unloader NodeCommandSender
|
||||
// Adapter is the NATS sender used to retry pending backend ops. When nil,
|
||||
// the state-reconciler pending-drain pass is a no-op (single-node mode).
|
||||
Adapter *RemoteUnloaderAdapter
|
||||
@@ -78,7 +78,7 @@ type ReplicaReconcilerOptions struct {
|
||||
// addresses. Matches the worker's token so HealthCheck auth succeeds.
|
||||
RegistrationToken string
|
||||
// Prober overrides the default gRPC health probe (used by tests).
|
||||
Prober ModelProber
|
||||
Prober ModelProber
|
||||
DB *gorm.DB
|
||||
Interval time.Duration // default 30s
|
||||
ScaleDownDelay time.Duration // default 5m
|
||||
@@ -191,7 +191,7 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
|
||||
// Pending-op drain for admin install — not a per-replica load.
|
||||
// Replica 0 is the conventional admin slot. Install is idempotent:
|
||||
// the worker short-circuits if the backend is already running.
|
||||
reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0)
|
||||
reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0, "", nil)
|
||||
if err != nil {
|
||||
applyErr = err
|
||||
} else if !reply.Success {
|
||||
|
||||
@@ -17,24 +17,24 @@ import (
|
||||
// Workers are generic — they don't have a fixed backend type.
|
||||
// The SmartRouter dynamically installs backends via NATS backend.install events.
|
||||
type BackendNode struct {
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:255" json:"name"`
|
||||
NodeType string `gorm:"size:32;default:backend" json:"node_type"` // backend, agent
|
||||
Address string `gorm:"size:255" json:"address"` // host:port for gRPC
|
||||
HTTPAddress string `gorm:"size:255" json:"http_address"` // host:port for HTTP file transfer
|
||||
Status string `gorm:"size:32;default:registering" json:"status"` // registering, healthy, unhealthy, draining, pending
|
||||
TokenHash string `gorm:"size:64" json:"-"` // SHA-256 of registration token
|
||||
TotalVRAM uint64 `gorm:"column:total_vram" json:"total_vram"` // Total GPU VRAM in bytes
|
||||
AvailableVRAM uint64 `gorm:"column:available_vram" json:"available_vram"` // Available GPU VRAM in bytes
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:255" json:"name"`
|
||||
NodeType string `gorm:"size:32;default:backend" json:"node_type"` // backend, agent
|
||||
Address string `gorm:"size:255" json:"address"` // host:port for gRPC
|
||||
HTTPAddress string `gorm:"size:255" json:"http_address"` // host:port for HTTP file transfer
|
||||
Status string `gorm:"size:32;default:registering" json:"status"` // registering, healthy, unhealthy, draining, pending
|
||||
TokenHash string `gorm:"size:64" json:"-"` // SHA-256 of registration token
|
||||
TotalVRAM uint64 `gorm:"column:total_vram" json:"total_vram"` // Total GPU VRAM in bytes
|
||||
AvailableVRAM uint64 `gorm:"column:available_vram" json:"available_vram"` // Available GPU VRAM in bytes
|
||||
// ReservedVRAM is a soft, in-tick reservation deducted by the scheduler when
|
||||
// it picks this node to load a model. Workers reset it back to 0 on each
|
||||
// heartbeat (the worker is the source of truth for actual free VRAM); the
|
||||
// reservation is only here to keep two scheduling decisions within the
|
||||
// same heartbeat window from over-committing the same node.
|
||||
ReservedVRAM uint64 `gorm:"column:reserved_vram;default:0" json:"reserved_vram"`
|
||||
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
|
||||
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
|
||||
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
|
||||
ReservedVRAM uint64 `gorm:"column:reserved_vram;default:0" json:"reserved_vram"`
|
||||
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
|
||||
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
|
||||
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
|
||||
// MaxReplicasPerModel caps how many replicas of any one model can run on
|
||||
// this node concurrently. Default 1 preserves the historical "one
|
||||
// (node, model)" assumption; set higher (via worker --max-replicas-per-model)
|
||||
@@ -44,12 +44,12 @@ type BackendNode struct {
|
||||
// admin override. When true, the worker's CLI value is ignored on
|
||||
// re-registration so the override survives worker restarts. Cleared
|
||||
// by an explicit "reset to worker default" action.
|
||||
MaxReplicasPerModelManuallySet bool `gorm:"column:max_replicas_per_model_manually_set;default:false" json:"max_replicas_per_model_manually_set"`
|
||||
APIKeyID string `gorm:"size:36" json:"-"` // auto-provisioned API key ID (for cleanup)
|
||||
AuthUserID string `gorm:"size:36" json:"-"` // auto-provisioned user ID (for cleanup)
|
||||
LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
MaxReplicasPerModelManuallySet bool `gorm:"column:max_replicas_per_model_manually_set;default:false" json:"max_replicas_per_model_manually_set"`
|
||||
APIKeyID string `gorm:"size:36" json:"-"` // auto-provisioned API key ID (for cleanup)
|
||||
AuthUserID string `gorm:"size:36" json:"-"` // auto-provisioned user ID (for cleanup)
|
||||
LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -79,17 +79,17 @@ const (
|
||||
// gRPC Address (each replica is a separate worker process on its own port),
|
||||
// and its own InFlight counter.
|
||||
type NodeModel struct {
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
NodeID string `gorm:"index;size:36" json:"node_id"`
|
||||
ModelName string `gorm:"index;size:255" json:"model_name"`
|
||||
ReplicaIndex int `gorm:"column:replica_index;default:0;index" json:"replica_index"`
|
||||
Address string `gorm:"size:255" json:"address"` // gRPC address for this replica's backend process
|
||||
State string `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
|
||||
InFlight int `json:"in_flight"` // number of active requests on this replica
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
LoadingBy string `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
|
||||
BackendType string `gorm:"size:128" json:"backend_type,omitempty"` // e.g. "llama-cpp"; used by reconciler to replicate loads
|
||||
ModelOptsBlob []byte `gorm:"type:bytea" json:"-"` // serialized pb.ModelOptions for replica scale-ups
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
NodeID string `gorm:"index;size:36" json:"node_id"`
|
||||
ModelName string `gorm:"index;size:255" json:"model_name"`
|
||||
ReplicaIndex int `gorm:"column:replica_index;default:0;index" json:"replica_index"`
|
||||
Address string `gorm:"size:255" json:"address"` // gRPC address for this replica's backend process
|
||||
State string `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
|
||||
InFlight int `json:"in_flight"` // number of active requests on this replica
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
LoadingBy string `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
|
||||
BackendType string `gorm:"size:128" json:"backend_type,omitempty"` // e.g. "llama-cpp"; used by reconciler to replicate loads
|
||||
ModelOptsBlob []byte `gorm:"type:bytea" json:"-"` // serialized pb.ModelOptions for replica scale-ups
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -1287,7 +1287,7 @@ func (r *NodeRegistry) UpdateMaxReplicasPerModel(ctx context.Context, nodeID str
|
||||
res := r.db.WithContext(ctx).Model(&BackendNode{}).
|
||||
Where("id = ?", nodeID).
|
||||
Updates(map[string]any{
|
||||
ColMaxReplicasPerModel: n,
|
||||
ColMaxReplicasPerModel: n,
|
||||
"max_replicas_per_model_manually_set": true,
|
||||
})
|
||||
if res.Error != nil {
|
||||
@@ -1460,7 +1460,7 @@ func (r *NodeRegistry) UpsertPendingBackendOp(ctx context.Context, nodeID, backe
|
||||
NextRetryAt: time.Now(),
|
||||
}
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "node_id"}, {Name: "backend"}, {Name: "op"}},
|
||||
Columns: []clause.Column{{Name: "node_id"}, {Name: "backend"}, {Name: "op"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"galleries", "next_retry_at"}),
|
||||
}).Create(&row).Error
|
||||
}
|
||||
@@ -1515,6 +1515,27 @@ func (r *NodeRegistry) RecordPendingBackendOpFailure(ctx context.Context, id uin
|
||||
})
|
||||
}
|
||||
|
||||
// RecordPendingBackendOpInFlight is the "soft failure" cousin of
|
||||
// RecordPendingBackendOpFailure. Used when a NATS install round-trip timed
|
||||
// out but the worker is still installing in the background. Stores the
|
||||
// message in LastError and pushes NextRetryAt out by `retryDelay` (typically
|
||||
// the install timeout) so the reconciler does not immediately re-fire
|
||||
// another install while the worker is still busy.
|
||||
//
|
||||
// Attempts is intentionally NOT incremented: an in-flight timeout is not a
|
||||
// failed attempt, it is a still-in-progress one. Incrementing it would let a
|
||||
// genuinely-progressing slow install (e.g. 30 GB CUDA image on Wi-Fi) trip
|
||||
// the maxPendingBackendOpAttempts cap in the reconciler and dead-letter the
|
||||
// row while the worker is still legitimately working.
|
||||
func (r *NodeRegistry) RecordPendingBackendOpInFlight(ctx context.Context, id uint, lastError string, retryDelay time.Duration) error {
|
||||
return r.db.WithContext(ctx).Model(&PendingBackendOp{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"last_error": lastError,
|
||||
"next_retry_at": time.Now().Add(retryDelay),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// backoffForAttempt is exponential from 30s doubling up to a 15m cap. The
|
||||
// reconciler tick is 30s so anything shorter would just re-fire immediately.
|
||||
func backoffForAttempt(attempts int) time.Duration {
|
||||
|
||||
@@ -688,7 +688,7 @@ func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNod
|
||||
|
||||
key := fmt.Sprintf("%s|%s|%s|%d", node.ID, backendType, modelID, replicaIndex)
|
||||
v, err, _ := r.installFlight.Do(key, func() (any, error) {
|
||||
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex)
|
||||
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex, "", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ type upgradeCall struct {
|
||||
replica int
|
||||
}
|
||||
|
||||
func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int) (*messaging.BackendInstallReply, error) {
|
||||
func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
|
||||
// installHook intentionally runs OUTSIDE the mutex: the hook may block
|
||||
// on a channel and we don't want to serialize concurrent callers,
|
||||
// which would defeat the singleflight-overlap test.
|
||||
|
||||
@@ -2,9 +2,15 @@ package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -28,7 +34,7 @@ type backendStopRequest struct {
|
||||
// nats.ErrNoResponders for old workers that don't subscribe to the new
|
||||
// backend.upgrade subject.
|
||||
type NodeCommandSender interface {
|
||||
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error)
|
||||
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error)
|
||||
UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error)
|
||||
DeleteBackend(nodeID, backendName string) (*messaging.BackendDeleteReply, error)
|
||||
ListBackends(nodeID string) (*messaging.BackendListReply, error)
|
||||
@@ -43,18 +49,33 @@ type NodeCommandSender interface {
|
||||
// This mirrors the local ModelLoader's startProcess()/deleteProcess() but
|
||||
// over NATS for remote nodes.
|
||||
type RemoteUnloaderAdapter struct {
|
||||
registry ModelLocator
|
||||
nats messaging.MessagingClient
|
||||
registry ModelLocator
|
||||
nats messaging.MessagingClient
|
||||
installTimeout time.Duration
|
||||
upgradeTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewRemoteUnloaderAdapter creates a new adapter.
|
||||
func NewRemoteUnloaderAdapter(registry ModelLocator, nats messaging.MessagingClient) *RemoteUnloaderAdapter {
|
||||
// NewRemoteUnloaderAdapter creates a new adapter. installTimeout and
|
||||
// upgradeTimeout govern the NATS request-reply deadlines for backend.install
|
||||
// and backend.upgrade respectively. Use
|
||||
// DistributedConfig.BackendInstallTimeoutOrDefault() /
|
||||
// BackendUpgradeTimeoutOrDefault() at construction.
|
||||
func NewRemoteUnloaderAdapter(registry ModelLocator, nats messaging.MessagingClient, installTimeout, upgradeTimeout time.Duration) *RemoteUnloaderAdapter {
|
||||
return &RemoteUnloaderAdapter{
|
||||
registry: registry,
|
||||
nats: nats,
|
||||
registry: registry,
|
||||
nats: nats,
|
||||
installTimeout: installTimeout,
|
||||
upgradeTimeout: upgradeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// InstallTimeout returns the configured backend.install round-trip timeout.
|
||||
// Used by DistributedBackendManager to push NextRetryAt out by this duration
|
||||
// when a worker times out replying but is still installing in the background.
|
||||
func (a *RemoteUnloaderAdapter) InstallTimeout() time.Duration {
|
||||
return a.installTimeout
|
||||
}
|
||||
|
||||
// UnloadRemoteModel finds the node(s) hosting the given model and tells them
|
||||
// to stop their backend process via NATS backend.stop event.
|
||||
// The worker process handles: Free() → kill process.
|
||||
@@ -87,18 +108,59 @@ func (a *RemoteUnloaderAdapter) UnloadRemoteModel(modelName string) error {
|
||||
// is on disk, the worker just spawns a process; only a missing binary
|
||||
// triggers a full gallery pull.
|
||||
//
|
||||
// Timeout: 3 minutes. Most calls return in under 2 seconds (process already
|
||||
// running). The 3-minute ceiling covers the cold-binary spawn-after-download
|
||||
// case while still failing fast enough to surface real worker hangs.
|
||||
// Timeout: configured via DistributedConfig.BackendInstallTimeoutOrDefault
|
||||
// (default 15m). Most calls return in under 2 seconds (process already
|
||||
// running). The 15-minute ceiling covers the cold-binary spawn-after-download
|
||||
// case on slow links (Jetson Wi-Fi, multi-GB CUDA images) while still
|
||||
// failing fast enough to surface real worker hangs.
|
||||
//
|
||||
// For force-reinstall (admin-driven Upgrade), use UpgradeBackend instead —
|
||||
// For force-reinstall (admin-driven Upgrade), use UpgradeBackend instead -
|
||||
// it lives on a different NATS subject so it cannot head-of-line-block
|
||||
// routine load traffic on the same worker.
|
||||
func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) {
|
||||
func (a *RemoteUnloaderAdapter) InstallBackend(
|
||||
nodeID, backendType, modelID, galleriesJSON, uri, name, alias string,
|
||||
replicaIndex int,
|
||||
opID string,
|
||||
onProgress func(messaging.BackendInstallProgressEvent),
|
||||
) (*messaging.BackendInstallReply, error) {
|
||||
subject := messaging.SubjectNodeBackendInstall(nodeID)
|
||||
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex)
|
||||
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex, "opID", opID)
|
||||
|
||||
return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
// Subscribe to the per-op progress subject BEFORE publishing the install
|
||||
// request so we don't miss early events. When onProgress is nil OR opID
|
||||
// is empty (the reconciler-driven retry path), skip subscription entirely:
|
||||
// silent installs cost nothing extra.
|
||||
var sub messaging.Subscription
|
||||
if onProgress != nil && opID != "" {
|
||||
progressSubject := messaging.SubjectNodeBackendInstallProgress(nodeID, opID)
|
||||
s, subErr := a.nats.Subscribe(progressSubject, func(raw []byte) {
|
||||
var ev messaging.BackendInstallProgressEvent
|
||||
if err := json.Unmarshal(raw, &ev); err != nil {
|
||||
xlog.Debug("malformed install progress event", "subject", progressSubject, "error", err)
|
||||
return
|
||||
}
|
||||
// Goroutine guard: a slow onProgress callback must not stall
|
||||
// the NATS reader thread.
|
||||
//
|
||||
// NOTE: events spawn one goroutine each, so ordering at the
|
||||
// consumer is best-effort. In practice the worker debounces to
|
||||
// ~250ms which is far larger than goroutine scheduling jitter,
|
||||
// so reordering is rare. The worker's final Flush() event is
|
||||
// intended to win as the terminal tick. A future hardening pass
|
||||
// could add a Seq uint64 field to BackendInstallProgressEvent
|
||||
// and drop stale-by-seq at the bridge if reordering becomes a
|
||||
// real UX issue.
|
||||
go onProgress(ev)
|
||||
})
|
||||
if subErr != nil {
|
||||
xlog.Warn("Failed to subscribe to install progress subject; proceeding without progress streaming",
|
||||
"subject", progressSubject, "error", subErr)
|
||||
} else {
|
||||
sub = s
|
||||
}
|
||||
}
|
||||
|
||||
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
Backend: backendType,
|
||||
ModelID: modelID,
|
||||
BackendGalleries: galleriesJSON,
|
||||
@@ -106,29 +168,46 @@ func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, gal
|
||||
Name: name,
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
}, 3*time.Minute)
|
||||
OpID: opID,
|
||||
}, a.installTimeout)
|
||||
|
||||
if sub != nil {
|
||||
_ = sub.Unsubscribe()
|
||||
}
|
||||
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// UpgradeBackend sends a backend.upgrade request-reply to a worker node.
|
||||
// The worker stops every live process for this backend, force-reinstalls
|
||||
// from the gallery (overwriting the on-disk artifact), and replies. The
|
||||
// next routine InstallBackend call spawns a fresh process with the new
|
||||
// binary — upgrade itself does not start a process.
|
||||
// binary - upgrade itself does not start a process.
|
||||
//
|
||||
// Timeout: 15 minutes. Real-world worst case observed: 8–10 minutes for
|
||||
// large CUDA-l4t backend images on Jetson over WiFi.
|
||||
// Timeout: configured via DistributedConfig.BackendUpgradeTimeoutOrDefault
|
||||
// (default 15m). Real-world worst case observed: 8-10 minutes for large
|
||||
// CUDA-l4t backend images on Jetson over WiFi.
|
||||
func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error) {
|
||||
subject := messaging.SubjectNodeBackendUpgrade(nodeID)
|
||||
xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex)
|
||||
|
||||
return messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{
|
||||
reply, err := messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{
|
||||
Backend: backendType,
|
||||
BackendGalleries: galleriesJSON,
|
||||
URI: uri,
|
||||
Name: name,
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
}, 15*time.Minute)
|
||||
}, a.upgradeTimeout)
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// installWithForceFallback is the rolling-update fallback used by
|
||||
@@ -141,7 +220,7 @@ func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, ga
|
||||
subject := messaging.SubjectNodeBackendInstall(nodeID)
|
||||
xlog.Warn("Falling back to legacy backend.install Force=true (old worker)", "nodeID", nodeID, "backend", backendType)
|
||||
|
||||
return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
Backend: backendType,
|
||||
BackendGalleries: galleriesJSON,
|
||||
URI: uri,
|
||||
@@ -149,7 +228,12 @@ func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, ga
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
Force: true,
|
||||
}, 15*time.Minute)
|
||||
}, a.upgradeTimeout)
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// ListBackends queries a worker node for its installed backends via NATS request-reply.
|
||||
@@ -228,3 +312,14 @@ func (a *RemoteUnloaderAdapter) StopNode(nodeID string) error {
|
||||
subject := messaging.SubjectNodeStop(nodeID)
|
||||
return a.nats.Publish(subject, nil)
|
||||
}
|
||||
|
||||
// isNATSTimeout returns true if err looks like a NATS request-reply timeout.
|
||||
// nats.ErrTimeout is the canonical sentinel; context.DeadlineExceeded can
|
||||
// also surface depending on the client's path; we accept both, plus a
|
||||
// string-match fallback for clients that return a bare error.
|
||||
func isNATSTimeout(err error) bool {
|
||||
if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
return err != nil && strings.Contains(err.Error(), "nats: timeout")
|
||||
}
|
||||
|
||||
@@ -3,13 +3,16 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
@@ -60,6 +63,7 @@ type publishCall struct {
|
||||
type requestCall struct {
|
||||
Subject string
|
||||
Data []byte
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (f *fakeMessagingClient) Publish(subject string, data any) error {
|
||||
@@ -93,10 +97,10 @@ func (f *fakeMessagingClient) SubscribeReply(_ string, _ func(data []byte, reply
|
||||
return &fakeSubscription{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeMessagingClient) Request(subject string, data []byte, _ time.Duration) ([]byte, error) {
|
||||
func (f *fakeMessagingClient) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.requestCalls = append(f.requestCalls, requestCall{Subject: subject, Data: data})
|
||||
f.requestCalls = append(f.requestCalls, requestCall{Subject: subject, Data: data, Timeout: timeout})
|
||||
return f.requestReply, f.requestErr
|
||||
}
|
||||
|
||||
@@ -119,7 +123,7 @@ var _ = Describe("RemoteUnloaderAdapter", func() {
|
||||
BeforeEach(func() {
|
||||
locator = &fakeModelLocator{}
|
||||
mc = &fakeMessagingClient{}
|
||||
adapter = NewRemoteUnloaderAdapter(locator, mc)
|
||||
adapter = NewRemoteUnloaderAdapter(locator, mc, 3*time.Minute, 15*time.Minute)
|
||||
})
|
||||
|
||||
Describe("UnloadRemoteModel", func() {
|
||||
@@ -154,7 +158,7 @@ var _ = Describe("RemoteUnloaderAdapter", func() {
|
||||
}
|
||||
// Use a messaging client that fails the first Publish call only.
|
||||
failOnce := &failOnceMessagingClient{inner: mc, failOn: 0}
|
||||
adapter = NewRemoteUnloaderAdapter(locator, failOnce)
|
||||
adapter = NewRemoteUnloaderAdapter(locator, failOnce, 3*time.Minute, 15*time.Minute)
|
||||
|
||||
Expect(adapter.UnloadRemoteModel("llama")).To(Succeed())
|
||||
|
||||
@@ -259,3 +263,96 @@ func (f *failOnceMessagingClient) Request(subject string, data []byte, timeout t
|
||||
|
||||
func (f *failOnceMessagingClient) IsConnected() bool { return true }
|
||||
func (f *failOnceMessagingClient) Close() {}
|
||||
|
||||
var _ = Describe("RemoteUnloaderAdapter timeout configuration", func() {
|
||||
It("passes the configured install timeout to the messaging client", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall("n1"), messaging.BackendInstallReply{Success: true, Address: "127.0.0.1:0"})
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 7*time.Minute, 11*time.Minute)
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "llama-cpp", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.calls).To(HaveLen(1))
|
||||
Expect(mc.calls[0].Timeout).To(Equal(7 * time.Minute))
|
||||
})
|
||||
|
||||
It("passes the configured upgrade timeout to the messaging client", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade("n1"), messaging.BackendUpgradeReply{Success: true})
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 7*time.Minute, 11*time.Minute)
|
||||
|
||||
_, err := adapter.UpgradeBackend("n1", "llama-cpp", "[]", "", "", "", 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.calls).To(HaveLen(1))
|
||||
Expect(mc.calls[0].Timeout).To(Equal(11 * time.Minute))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RemoteUnloaderAdapter NATS timeout handling", func() {
|
||||
It("wraps nats.ErrTimeout from InstallBackend in galleryop.ErrWorkerStillInstalling", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall("n1"), nats.ErrTimeout)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 100*time.Millisecond, 1*time.Second)
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue(),
|
||||
"expected wrapped ErrWorkerStillInstalling, got %v", err)
|
||||
})
|
||||
|
||||
It("does NOT wrap non-timeout errors", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall("n1"), nats.ErrNoResponders)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 100*time.Millisecond, 1*time.Second)
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeFalse())
|
||||
Expect(errors.Is(err, nats.ErrNoResponders)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RemoteUnloaderAdapter install progress streaming", func() {
|
||||
It("forwards BackendInstallProgressEvent values into the onProgress callback when the worker publishes them", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall("n1"), messaging.BackendInstallReply{Success: true, Address: "127.0.0.1:0"})
|
||||
mc.scheduleProgressPublish("n1", "op-abc", []messaging.BackendInstallProgressEvent{
|
||||
{OpID: "op-abc", NodeID: "n1", Backend: "vllm", FileName: "vllm.tar.zst", Current: "100 MB", Total: "1 GB", Percentage: 10},
|
||||
{OpID: "op-abc", NodeID: "n1", Backend: "vllm", FileName: "vllm.tar.zst", Current: "500 MB", Total: "1 GB", Percentage: 50},
|
||||
})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 1*time.Second, 1*time.Second)
|
||||
var (
|
||||
received []messaging.BackendInstallProgressEvent
|
||||
mu sync.Mutex
|
||||
)
|
||||
onProgress := func(ev messaging.BackendInstallProgressEvent) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
received = append(received, ev)
|
||||
}
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "op-abc", onProgress)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Eventually(func() int {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(received)
|
||||
}, "1s").Should(Equal(2))
|
||||
})
|
||||
|
||||
It("does NOT subscribe when onProgress is nil (reconciler retry path)", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall("n1"), messaging.BackendInstallReply{Success: true})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 1*time.Second, 1*time.Second)
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.subscribeCalls()).To(BeEmpty(),
|
||||
"reconciler-driven retries must not subscribe to the per-op progress subject")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
@@ -15,7 +17,7 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(nodeID),
|
||||
messaging.BackendUpgradeReply{Success: true})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(reply.Success).To(BeTrue())
|
||||
@@ -24,7 +26,7 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
|
||||
It("returns the underlying error when the subject has no responders", func() {
|
||||
mc := newScriptedMessagingClient() // unscripted subject => fakeNoRespondersErr by harness convention
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
_, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
115
core/services/skills/skills_mcp_test.go
Normal file
115
core/services/skills/skills_mcp_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package skills_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
agiSkills "github.com/mudler/LocalAGI/services/skills"
|
||||
localskills "github.com/mudler/LocalAI/core/services/skills"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestSkillsMCP(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Skills MCP test")
|
||||
}
|
||||
|
||||
// listSkillsResult mirrors the output struct of skillserver's list_skills tool.
|
||||
type listSkillsResult struct {
|
||||
Skills []struct {
|
||||
ID string `json:"id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
} `json:"skills"`
|
||||
}
|
||||
|
||||
// Exercises the same wire the agent uses at runtime: open an in-process
|
||||
// MCP session via LocalAGI's skills.Service, create a skill through the
|
||||
// LocalAI FilesystemManager, then list_skills on the still-open session.
|
||||
// Guards against regressions in the manager <-> MCP session lifecycle
|
||||
// (e.g. cached manager not picking up newly-created skills).
|
||||
var _ = Describe("Skills exposed to agent via MCP", func() {
|
||||
var (
|
||||
stateDir string
|
||||
svc *agiSkills.Service
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
stateDir, err = os.MkdirTemp("", "skills-mcp-test")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Create the LocalAGI skills service (this is what AgentPoolService wires
|
||||
// into LocalAGI's state.NewAgentPool for MCP session exposure).
|
||||
svc, err = agiSkills.NewService(stateDir)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
Expect(os.RemoveAll(stateDir)).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns a skill created after the MCP session was established", func() {
|
||||
// Open the MCP session first — this is what the agent does at startup
|
||||
// with EnableSkills=true, before any skill might exist.
|
||||
session, err := svc.GetMCPSession(ctx)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(session).NotTo(BeNil())
|
||||
|
||||
res, err := session.CallTool(ctx, &mcp.CallToolParams{Name: "list_skills"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.IsError).To(BeFalse())
|
||||
var initial listSkillsResult
|
||||
Expect(decodeMCPText(res, &initial)).To(Succeed())
|
||||
Expect(initial.Skills).To(BeEmpty(), "no skills should exist initially")
|
||||
|
||||
// Create a skill via the LocalAI FilesystemManager — same code path the
|
||||
// /api/agents/skills POST endpoint takes.
|
||||
mgr := localskills.NewFilesystemManager(svc)
|
||||
_, err = mgr.Create("talk-like-pirate", "Talk like a pirate", "Speak in pirate-style.", "", "", "", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Re-list via the SAME already-open session: the manager is shared,
|
||||
// so a freshly-created skill must be visible without re-attaching.
|
||||
res, err = session.CallTool(ctx, &mcp.CallToolParams{Name: "list_skills"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.IsError).To(BeFalse())
|
||||
|
||||
var got listSkillsResult
|
||||
Expect(decodeMCPText(res, &got)).To(Succeed())
|
||||
|
||||
ids := make([]string, 0, len(got.Skills))
|
||||
for _, s := range got.Skills {
|
||||
ids = append(ids, s.ID)
|
||||
}
|
||||
Expect(ids).To(ContainElement("talk-like-pirate"))
|
||||
})
|
||||
})
|
||||
|
||||
func mcpText(res *mcp.CallToolResult) string {
|
||||
text := ""
|
||||
for _, c := range res.Content {
|
||||
if tc, ok := c.(*mcp.TextContent); ok {
|
||||
text += tc.Text
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func decodeMCPText(res *mcp.CallToolResult, out any) error {
|
||||
text := mcpText(res)
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal([]byte(text), out)
|
||||
}
|
||||
@@ -7,14 +7,22 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// installProgressDebounce is the leading-edge window the worker uses when
|
||||
// streaming download progress to the master. 250ms caps wire chatter at
|
||||
// ~4 events/sec per in-flight install while still surfacing every
|
||||
// meaningful percentage jump.
|
||||
const installProgressDebounce = 250 * time.Millisecond
|
||||
|
||||
// buildProcessKey is the supervisor's stable identifier for a backend gRPC
|
||||
// process. It includes the replica index so the same model can run multiple
|
||||
// processes on a worker simultaneously without colliding on the same map slot
|
||||
@@ -100,6 +108,20 @@ func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// When the master tagged this install with an OpID, stream the
|
||||
// gallery download progress back to it on the per-op NATS subject.
|
||||
// Old masters that omit OpID stay on the silent path so they keep
|
||||
// working without changes. The publisher releases its mutex before
|
||||
// every Publish so a slow link never stalls the download loop, and
|
||||
// the deferred Flush guarantees a terminal-percentage event reaches
|
||||
// the master even when the install errors out.
|
||||
var downloadCb func(file, current, total string, percentage float64)
|
||||
if req.OpID != "" && s.nats != nil {
|
||||
publisher := nodes.NewDebouncedInstallProgressPublisher(s.nats, s.nodeID, req.OpID, req.Backend, installProgressDebounce)
|
||||
downloadCb = publisher.OnDownload
|
||||
defer publisher.Flush()
|
||||
}
|
||||
|
||||
// On upgrade, run the gallery install path even if the binary already
|
||||
// exists on disk: findBackend would otherwise short-circuit and we'd
|
||||
// restart the same stale binary. The force flag passed to
|
||||
@@ -112,14 +134,14 @@ func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest,
|
||||
if req.URI != "" {
|
||||
xlog.Info("Installing backend from external URI", "backend", req.Backend, "uri", req.URI, "force", force)
|
||||
if err := galleryop.InstallExternalBackend(
|
||||
context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("installing backend from gallery: %w", err)
|
||||
}
|
||||
} else {
|
||||
xlog.Info("Installing backend from gallery", "backend", req.Backend, "force", force)
|
||||
if err := gallery.InstallBackendFromGallery(
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, force, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, downloadCb, force, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("installing backend from gallery: %w", err)
|
||||
}
|
||||
|
||||
@@ -444,11 +444,15 @@ These llama.cpp options are passed through the `options:` array.
|
||||
|
||||
### Prompt Caching
|
||||
|
||||
The recommended way to enable prompt caching for the `llama-cpp` backend is the **server-side prompt cache** controlled by `cache_ram` / `kv_unified` / `cache_idle_slots` in the `options:` array (see [llama.cpp backend options]({{%relref "features/text-generation#server-side-prompt-cache-repeated-system-prompts" %}})). It's on by default since LocalAI v4.3 and is what gives repeated system prompts a near-zero prefill on the second call.
|
||||
|
||||
The fields below come from upstream llama.cpp's **CLI completion tool** and are passed through to the gRPC backend for compatibility, but the gRPC server itself does not consume them. Leave them unset unless you're targeting a non-llama-cpp backend that reads them.
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `prompt_cache_path` | string | Path to store prompt cache (relative to models directory) |
|
||||
| `prompt_cache_all` | bool | Cache all prompts automatically |
|
||||
| `prompt_cache_ro` | bool | Read-only prompt cache |
|
||||
| `prompt_cache_path` | string | (legacy / unused by llama-cpp gRPC server) Path to a file-backed prompt cache for upstream's CLI completion tool. |
|
||||
| `prompt_cache_all` | bool | (legacy / unused by llama-cpp gRPC server) |
|
||||
| `prompt_cache_ro` | bool | (legacy / unused by llama-cpp gRPC server) |
|
||||
|
||||
### Text Processing
|
||||
|
||||
|
||||
@@ -253,10 +253,12 @@ User API keys inherit the creating user's role. Admin keys grant admin access; u
|
||||
| `GET` | `/api/auth/api-keys` | List user's API keys | Yes |
|
||||
| `DELETE` | `/api/auth/api-keys/:id` | Revoke API key | Yes |
|
||||
| `GET` | `/api/auth/usage` | User's own usage stats | Yes |
|
||||
| `GET` | `/api/auth/usage/sources` | User's own per-API-key / per-source breakdown | Yes |
|
||||
| `GET` | `/api/auth/admin/users` | List all users | Admin |
|
||||
| `PUT` | `/api/auth/admin/users/:id/role` | Change user role | Admin |
|
||||
| `DELETE` | `/api/auth/admin/users/:id` | Delete user | Admin |
|
||||
| `GET` | `/api/auth/admin/usage` | All users' usage stats | Admin |
|
||||
| `GET` | `/api/auth/admin/usage/sources` | All users' per-API-key / per-source breakdown | Admin |
|
||||
| `POST` | `/api/auth/admin/invites` | Create invite link | Admin |
|
||||
| `GET` | `/api/auth/admin/invites` | List all invites | Admin |
|
||||
| `DELETE` | `/api/auth/admin/invites/:id` | Revoke unused invite | Admin |
|
||||
@@ -327,10 +329,79 @@ curl "http://localhost:8080/api/auth/admin/usage?period=month&user_id=<user-id>"
|
||||
### Usage Dashboard
|
||||
|
||||
The web UI Usage page provides:
|
||||
- **Period selector** — switch between day, week, month, and all-time views
|
||||
- **Summary cards** — total requests, prompt tokens, completion tokens, total tokens
|
||||
- **By Model table** — per-model breakdown with visual usage bars
|
||||
- **By User table** (admin only) — per-user breakdown across all models
|
||||
- **Period selector** - switch between day, week, month, and all-time views
|
||||
- **Summary cards** - total requests, prompt tokens, completion tokens, total tokens
|
||||
- **By Model table** - per-model breakdown with visual usage bars
|
||||
- **By User table** (admin only) - per-user breakdown across all models
|
||||
- **Sources tab** - per-API-key and per-source breakdown (described below)
|
||||
|
||||
### Per-API-key Breakdown
|
||||
|
||||
The **Sources** tab on the Usage page surfaces a third dimension of the same data: traffic broken down by API key and by request source. Three source classes are tracked:
|
||||
|
||||
- **API key** - request authenticated with a named user API key (`Authorization: Bearer lai-...`, `x-api-key`, or `token` cookie). Each key shows up with its label (snapshotted at write time, so revoked keys still display the original name).
|
||||
- **Web UI** - request authenticated with a browser session cookie.
|
||||
- **Legacy** - request authenticated with an env-configured `LOCALAI_API_KEY`. Visible to admins only.
|
||||
|
||||
The Sources tab is visible to every authenticated user. Non-admins see only their own keys plus their own Web UI traffic (legacy is filtered server-side). Admins see every key from every user.
|
||||
|
||||
The tab is laid out as:
|
||||
|
||||
- A **source mix ribbon** showing the percentage split across the three classes.
|
||||
- A **top-N + Other stacked time chart** (the top 7 sources by total tokens are plotted individually; the rest roll up into "Other").
|
||||
- A **searchable, sortable table** of every key plus the Web UI and Legacy pseudo-rows. Click a row to filter the chart to that source.
|
||||
|
||||
#### Endpoints
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|--------|------|------|-------------|
|
||||
| `GET` | `/api/auth/usage/sources` | Self | Caller's per-source breakdown. Excludes legacy. |
|
||||
| `GET` | `/api/auth/admin/usage/sources` | Admin | All users' per-source breakdown. Accepts `user_id` and `api_key_id` filters. Includes legacy. |
|
||||
|
||||
Both endpoints accept the same `period` parameter (`day`, `week`, `month`, `all`) as `/api/auth/usage`.
|
||||
|
||||
```bash
|
||||
# Your own per-source usage for the last week
|
||||
curl "http://localhost:8080/api/auth/usage/sources?period=week" \
|
||||
-H "Authorization: Bearer <key>"
|
||||
|
||||
# Admin: filter to a single API key across all users
|
||||
curl "http://localhost:8080/api/auth/admin/usage/sources?period=month&api_key_id=<key-id>" \
|
||||
-H "Authorization: Bearer <admin-key>"
|
||||
```
|
||||
|
||||
**Response shape:**
|
||||
|
||||
```json
|
||||
{
|
||||
"buckets": [
|
||||
{ "bucket": "2026-05-19", "source": "apikey",
|
||||
"api_key_id": "uuid", "api_key_name": "ci-runner",
|
||||
"total_tokens": 20000, "request_count": 142, "...": "..." },
|
||||
{ "bucket": "2026-05-19", "source": "web",
|
||||
"total_tokens": 300, "request_count": 11, "...": "..." }
|
||||
],
|
||||
"totals": {
|
||||
"by_source": {
|
||||
"apikey": { "tokens": 1234567, "requests": 8420 },
|
||||
"web": { "tokens": 92000, "requests": 211 }
|
||||
},
|
||||
"by_key": [
|
||||
{ "api_key_id": "uuid", "api_key_name": "ci-runner",
|
||||
"tokens": 1234567, "requests": 8420,
|
||||
"last_used": "2026-05-20T12:34:56Z" }
|
||||
],
|
||||
"grand_total": { "tokens": 1326567, "requests": 8631 }
|
||||
},
|
||||
"truncated": false
|
||||
}
|
||||
```
|
||||
|
||||
The `by_key` list is server-sorted by tokens descending and capped at 200 entries. When more keys would qualify, the response sets `"truncated": true` so the UI can show a notice.
|
||||
|
||||
#### Migration of pre-feature data
|
||||
|
||||
Usage rows recorded before this feature was introduced predate the `source` column. On startup, `InitDB` backfills them as `legacy` when the row's `user_id` is the synthetic `legacy-api-key` user, and as `web` otherwise. The migration is idempotent; existing aggregations remain correct after the upgrade.
|
||||
|
||||
## Combining Auth Modes
|
||||
|
||||
|
||||
@@ -86,6 +86,8 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
|
||||
| `--auto-approve-nodes` | `LOCALAI_AUTO_APPROVE_NODES` | `false` | Auto-approve new worker nodes (skip admin approval) |
|
||||
| `--auth` | `LOCALAI_AUTH` | `false` | **Must be `true`** for distributed mode |
|
||||
| `--auth-database-url` | `LOCALAI_AUTH_DATABASE_URL` | *(required)* | PostgreSQL connection URL |
|
||||
| `--backend-install-timeout` | `LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT` | `15m` | NATS round-trip timeout for `backend.install` requests sent to worker nodes. Raise on slow links pulling multi-GB OCI images (e.g. Jetson over Wi-Fi). If the round-trip times out but the worker is still installing in the background, the admin UI shows the operation as `still installing in background` rather than failed, and the reconciler confirms completion via the next `backend.list` poll. |
|
||||
| `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | NATS round-trip timeout for `backend.upgrade` requests (force-reinstall path). Same semantics as install timeout. |
|
||||
|
||||
### Optional: S3 Object Storage
|
||||
|
||||
@@ -103,6 +105,46 @@ When S3 is not configured, model files are transferred directly from the fronten
|
||||
|
||||
For high-throughput or very large model files, S3 can be more efficient since it avoids streaming through the frontend.
|
||||
|
||||
### Install Progress Streaming
|
||||
|
||||
While a worker is pulling an OCI image for a backend install, it publishes
|
||||
debounced progress events (~250ms) on `nodes.<nodeID>.backend.install.<opID>.progress`.
|
||||
The frontend subscribes for the duration of the install request and forwards each
|
||||
event into the operation status so the admin UI surfaces per-file byte progress
|
||||
and percentage in real time, the same way local-mode installs already do.
|
||||
|
||||
The NATS reply for `backend.install` is still the source of truth for the
|
||||
final success/failure; dropped progress events are acceptable and the install
|
||||
completes regardless.
|
||||
|
||||
**Mixed-version clusters:** Workers running pre-2026-05-22 code do not publish
|
||||
on the new progress subject. New frontends tolerate that silently. The install
|
||||
still completes via the reply; the UI keeps showing the message from the
|
||||
install-timeout fallback path (`still installing in background`) until the
|
||||
pending operation row clears.
|
||||
|
||||
#### Per-Node Operations Breakdown
|
||||
|
||||
When an admin backend install fans out to more than one worker, the
|
||||
**Operations Bar** at the top of the admin UI shows a per-node breakdown.
|
||||
Click the "N nodes" chevron on the operation row to expand a list with one
|
||||
entry per target node, each carrying:
|
||||
|
||||
- A color-coded status pill: queued (gray), downloading (blue), worker
|
||||
busy / running on worker (yellow), done (green), failed (red).
|
||||
- The current file being pulled, current/total bytes, percentage.
|
||||
- A thin per-node progress bar.
|
||||
- Any error message returned by the worker for that node.
|
||||
|
||||
The yellow "Worker busy" pill appears when the NATS round-trip to a node
|
||||
timed out but the worker is still installing in the background. Hover the
|
||||
pill for the full tooltip.
|
||||
|
||||
The breakdown is driven by the `nodes` array on the `/api/operations`
|
||||
response, which the frontend polls every second. Single-node installs and
|
||||
model installs render the same single-line card as before: the per-node
|
||||
section only appears when more than one node is involved.
|
||||
|
||||
## Worker Configuration
|
||||
|
||||
Workers are started with the `worker` subcommand. Each worker is generic — it doesn't need a backend type at startup:
|
||||
|
||||
@@ -499,7 +499,7 @@ The `llama.cpp` backend supports additional configuration options that can be sp
|
||||
|--------|------|-------------|---------|
|
||||
| `use_jinja` or `jinja` | boolean | Enable Jinja2 template processing for chat templates. When enabled, the backend uses Jinja2-based chat templates from the model for formatting messages. | `use_jinja:true` |
|
||||
| `context_shift` | boolean | Enable context shifting, which allows the model to dynamically adjust context window usage. | `context_shift:true` |
|
||||
| `cache_ram` | integer | Set the maximum RAM cache size in MiB for KV cache. Use `-1` for unlimited (default). | `cache_ram:2048` |
|
||||
| `cache_ram` | integer | Size budget in MiB for the **server-side prompt cache** (a host-RAM store of idle slot KV states that's reloaded on a prompt-prefix hit, see [upstream PR #16391](https://github.com/ggml-org/llama.cpp/pull/16391)). Default: `-1` (no limit). `0` disables the prompt cache entirely. Together with `kv_unified` and `cache_idle_slots` this is what makes a repeated system prompt skip prefill on subsequent calls. | `cache_ram:4096` |
|
||||
| `parallel` or `n_parallel` | integer | Enable parallel request processing. When set to a value greater than 1, enables continuous batching for handling multiple requests concurrently. | `parallel:4` |
|
||||
| `grpc_servers` or `rpc_servers` | string | Comma-separated list of gRPC server addresses for distributed inference. Allows distributing workload across multiple llama.cpp workers. | `grpc_servers:localhost:50051,localhost:50052` |
|
||||
| `fit_params` or `fit` | boolean | Enable auto-adjustment of model/context parameters to fit available device memory. Default: `true`. | `fit_params:true` |
|
||||
@@ -512,8 +512,10 @@ The `llama.cpp` backend supports additional configuration options that can be sp
|
||||
| `check_tensors` | boolean | Validate tensor data for invalid values during model loading. Default: `false`. | `check_tensors:true` |
|
||||
| `warmup` | boolean | Enable warmup run after model loading. Default: `true`. | `warmup:false` |
|
||||
| `no_op_offload` | boolean | Disable offloading host tensor operations to device. Default: `false`. | `no_op_offload:true` |
|
||||
| `kv_unified` or `unified_kv` | boolean | Enable unified KV cache. Default: `false`. | `kv_unified:true` |
|
||||
| `n_ctx_checkpoints` or `ctx_checkpoints` | integer | Maximum number of context checkpoints per slot. Default: `8`. | `ctx_checkpoints:4` |
|
||||
| `kv_unified` or `unified_kv` | boolean | Use a single unified KV buffer shared across all sequences. Default: `true` (LocalAI override; upstream defaults to `false` but auto-enables it when slot count is auto). **Required for `cache_idle_slots` to work**: without it the server force-disables idle-slot saving at init, and the prompt cache is never written across requests. | `kv_unified:false` |
|
||||
| `cache_idle_slots` or `idle_slots_cache` | boolean | On a new task, save the previous slot's KV state into the prompt cache (and clear the slot) so a later request with the same prefix can warm-load it. Default: `true`. Auto-disabled by the server if `kv_unified=false` or `cache_ram=0`. | `cache_idle_slots:false` |
|
||||
| `n_ctx_checkpoints` or `ctx_checkpoints` | integer | Maximum number of context checkpoints per slot (used for partial-prefix recovery, e.g. SWA). Default: `32`. | `ctx_checkpoints:16` |
|
||||
| `checkpoint_every_nt` or `checkpoint_every_n_tokens` | integer | Create a context checkpoint every N tokens during prefill. `-1` disables checkpointing. Default: `8192`. | `checkpoint_every_nt:4096` |
|
||||
| `split_mode` or `sm` | string | How to split the model across multiple GPUs: `none` (single GPU only), `layer` (default — split layers and KV across GPUs), `row` (split rows across GPUs), `tensor` (experimental tensor parallelism — requires `flash_attention: true`, no KV-cache quantization, manually set `context_size`, and a llama.cpp build that includes [#19378](https://github.com/ggml-org/llama.cpp/pull/19378)). | `split_mode:tensor` |
|
||||
|
||||
**Example configuration with options:**
|
||||
@@ -535,6 +537,27 @@ options:
|
||||
|
||||
**Note:** The `parallel` option can also be set via the `LLAMACPP_PARALLEL` environment variable, and `grpc_servers` can be set via the `LLAMACPP_GRPC_SERVERS` environment variable. Options specified in the YAML file take precedence over environment variables.
|
||||
|
||||
##### Server-side prompt cache (repeated system prompts)
|
||||
|
||||
Agents, coding assistants, and Anthropic/OpenAI-compatible CLIs typically resend the same large system prompt on every turn. The llama.cpp server can short-circuit prefill for the matching prefix by stashing idle slot KV states in host RAM and reloading them on a hit. Three settings interact:
|
||||
|
||||
| Setting | Default | Role |
|
||||
|---|---|---|
|
||||
| `cache_ram:N` | `-1` (no limit) | Allocates the host-side prompt cache. `0` disables it. |
|
||||
| `kv_unified:true` | `true` | Single unified KV buffer (**prerequisite** for idle-slot saving). |
|
||||
| `cache_idle_slots:true` | `true` | Persists the idle slot's KV into the prompt cache on task switch. |
|
||||
|
||||
All three are on by default since LocalAI v4.3, so the prompt cache works out of the box for the common single-slot setup. If you're on an older release, or you've explicitly disabled one of them, add the following to recover the behaviour:
|
||||
|
||||
```yaml
|
||||
options:
|
||||
- cache_ram:4096 # or -1 for no limit
|
||||
- kv_unified:true
|
||||
- cache_idle_slots:true
|
||||
```
|
||||
|
||||
Set `cache_ram:0` to opt out of the prompt cache entirely (saves host RAM at the cost of re-prefilling repeated prompts).
|
||||
|
||||
#### Reference
|
||||
|
||||
- [llama](https://github.com/ggerganov/llama.cpp)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -220,7 +220,7 @@ require (
|
||||
github.com/mschoch/smat v0.2.0 // indirect
|
||||
github.com/mudler/LocalAGI v0.0.0-20260508125235-37810d918a87
|
||||
github.com/mudler/localrecall v0.6.1-0.20260507074622-a7724fef6f81 // indirect
|
||||
github.com/mudler/skillserver v0.0.6
|
||||
github.com/mudler/skillserver v0.0.7-0.20260520220837-a7317cbf9145
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
|
||||
github.com/philippgille/chromem-go v0.7.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -984,6 +984,10 @@ github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8 h1:Ry8RiWy8fZ6Ff4E7d
|
||||
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
|
||||
github.com/mudler/skillserver v0.0.6 h1:ixz6wUekLdTmbnpAavCkTydDF6UdXAG3ncYufSPK9G0=
|
||||
github.com/mudler/skillserver v0.0.6/go.mod h1:z3yFhcL9bSykmmh6xgGu0hyoItd4CnxgtWMEWw8uFJU=
|
||||
github.com/mudler/skillserver v0.0.7-0.20260520212528-3dae7f041b1e h1:ryXE1UEzGhLkDFYuaxJ0fZ6fg4l++TWfMCTJ1E7bYS8=
|
||||
github.com/mudler/skillserver v0.0.7-0.20260520212528-3dae7f041b1e/go.mod h1:z3yFhcL9bSykmmh6xgGu0hyoItd4CnxgtWMEWw8uFJU=
|
||||
github.com/mudler/skillserver v0.0.7-0.20260520220837-a7317cbf9145 h1:z59tA3IDYPt71nzH1jpxeaA1LuDw8aZfpTQFNU43Zb8=
|
||||
github.com/mudler/skillserver v0.0.7-0.20260520220837-a7317cbf9145/go.mod h1:z3yFhcL9bSykmmh6xgGu0hyoItd4CnxgtWMEWw8uFJU=
|
||||
github.com/mudler/water v0.0.0-20250808092830-dd90dcf09025 h1:WFLP5FHInarYGXi6B/Ze204x7Xy6q/I4nCZnWEyPHK0=
|
||||
github.com/mudler/water v0.0.0-20250808092830-dd90dcf09025/go.mod h1:QuIFdRstyGJt+MTTkWY+mtD7U6xwjOR6SwKUjmLZtR4=
|
||||
github.com/mudler/xlog v0.0.6 h1:3nBV4THK8kY0Y8FDXXvWAnuAJoOyO7EAXteJeAoHUC0=
|
||||
|
||||
@@ -36,7 +36,7 @@ func ExtractArchive(archive, dst string) error {
|
||||
OverwriteExisting: true,
|
||||
MkdirAll: true,
|
||||
ImplicitTopLevelFolder: false,
|
||||
ContinueOnError: true,
|
||||
ContinueOnError: false,
|
||||
}
|
||||
|
||||
switch v := uaIface.(type) {
|
||||
|
||||
@@ -225,7 +225,7 @@ var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func()
|
||||
// newTestSmartRouter creates a SmartRouter with NATS wired up and a mock
|
||||
// backend.install handler that always replies success for all registered nodes.
|
||||
newTestSmartRouter := func(reg *nodes.NodeRegistry, extraOpts ...nodes.SmartRouterOptions) *nodes.SmartRouter {
|
||||
unloader := nodes.NewRemoteUnloaderAdapter(reg, infra.NC)
|
||||
unloader := nodes.NewRemoteUnloaderAdapter(reg, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
|
||||
opts := nodes.SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
@@ -395,7 +395,7 @@ var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create RemoteUnloaderAdapter and unload model
|
||||
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
err = unloader.UnloadRemoteModel("old-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
@@ -175,7 +176,7 @@ var _ = Describe("Model and Backend Managers", Label("Distributed"), func() {
|
||||
appCfg := config.NewApplicationConfig()
|
||||
appCfg.SystemState = ss
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
distMgr := nodes.NewDistributedModelManager(appCfg, ml, adapter)
|
||||
|
||||
err = distMgr.DeleteModel("big-model")
|
||||
@@ -251,8 +252,8 @@ var _ = Describe("Model and Backend Managers", Label("Distributed"), func() {
|
||||
appCfg := config.NewApplicationConfig()
|
||||
appCfg.SystemState = ss
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
distMgr := nodes.NewDistributedBackendManager(appCfg, ml, adapter, registry)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
distMgr := nodes.NewDistributedBackendManager(appCfg, ml, adapter, registry, nil)
|
||||
|
||||
err = distMgr.DeleteBackend("my-backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -298,8 +299,8 @@ var _ = Describe("Model and Backend Managers", Label("Distributed"), func() {
|
||||
appCfg := config.NewApplicationConfig()
|
||||
appCfg.SystemState = ss
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
distMgr := nodes.NewDistributedBackendManager(appCfg, ml, adapter, registry)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
distMgr := nodes.NewDistributedBackendManager(appCfg, ml, adapter, registry, nil)
|
||||
|
||||
// Should NOT return an error even though the backend doesn't exist locally
|
||||
err = distMgr.DeleteBackend("remote-only-backend")
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
@@ -56,8 +57,8 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
|
||||
|
||||
FlushNATS(infra.NC)
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
installReply, err := adapter.InstallBackend(node.ID, "llama-cpp", "", "", "", "", "", 0)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
installReply, err := adapter.InstallBackend(node.ID, "llama-cpp", "", "", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(installReply.Success).To(BeTrue())
|
||||
})
|
||||
@@ -77,8 +78,8 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
|
||||
|
||||
FlushNATS(infra.NC)
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
installReply, err := adapter.InstallBackend(node.ID, "nonexistent", "", "", "", "", "", 0)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
installReply, err := adapter.InstallBackend(node.ID, "nonexistent", "", "", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(installReply.Success).To(BeFalse())
|
||||
Expect(installReply.Error).To(ContainSubstring("backend not found"))
|
||||
@@ -103,7 +104,7 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
|
||||
FlushNATS(infra.NC)
|
||||
|
||||
// Frontend calls UnloadRemoteModel (triggered by UI "Stop" or WatchDog)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
Expect(adapter.UnloadRemoteModel("whisper-large")).To(Succeed())
|
||||
|
||||
Eventually(func() int32 { return stopReceived.Load() }, "5s").Should(Equal(int32(1)))
|
||||
@@ -133,14 +134,14 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
|
||||
|
||||
FlushNATS(infra.NC)
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
adapter.UnloadRemoteModel("shared-model")
|
||||
|
||||
Eventually(func() int32 { return count.Load() }, "5s").Should(Equal(int32(2)))
|
||||
})
|
||||
|
||||
It("should be no-op for models not on any node", func() {
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
Expect(adapter.UnloadRemoteModel("nonexistent-model")).To(Succeed())
|
||||
})
|
||||
})
|
||||
@@ -161,7 +162,7 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
|
||||
|
||||
FlushNATS(infra.NC)
|
||||
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
Expect(adapter.StopNode(node.ID)).To(Succeed())
|
||||
|
||||
Eventually(func() int32 { return stopped.Load() }, "5s").Should(Equal(int32(1)))
|
||||
|
||||
@@ -3,6 +3,7 @@ package distributed_test
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
@@ -78,7 +79,7 @@ var _ = Describe("SmartRouter trackingKey", Label("Distributed"), func() {
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
nodeID = node.ID
|
||||
|
||||
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
||||
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
|
||||
router = nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user