mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-25 01:02:05 -04:00
Compare commits
37 Commits
dependabot
...
feat/expos
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c778ad0f6d | ||
|
|
8b2697f39a | ||
|
|
1af79c1b0f | ||
|
|
40ec4ffc94 | ||
|
|
04407d24f3 | ||
|
|
1c4bdfd1d6 | ||
|
|
799215cdc6 | ||
|
|
88306d562d | ||
|
|
df8418cb2d | ||
|
|
42d6e52fd7 | ||
|
|
a867b3d2a8 | ||
|
|
63448826b1 | ||
|
|
be1041de0c | ||
|
|
b85b7e29df | ||
|
|
17791fb741 | ||
|
|
1a30020a82 | ||
|
|
8bbe89a537 | ||
|
|
dcc5599f89 | ||
|
|
a95f4e63e0 | ||
|
|
dfd19a3f88 | ||
|
|
d7387c725c | ||
|
|
63d84a5705 | ||
|
|
1198d10b58 | ||
|
|
a0f3e26245 | ||
|
|
e4cc1f11f3 | ||
|
|
6ed269d0b9 | ||
|
|
5756fb046d | ||
|
|
7980629bc5 | ||
|
|
d0a59be9de | ||
|
|
5cda4f1ccf | ||
|
|
c500461c69 | ||
|
|
834ecc36bf | ||
|
|
61bf34ea2f | ||
|
|
0b2ae3c6ca | ||
|
|
4735345105 | ||
|
|
7384fd800b | ||
|
|
6942713d85 |
@@ -16,7 +16,8 @@ side (`pkg/oci/cosignverify` plus the gallery YAML).
|
||||
per-arch manifest before checking signatures.
|
||||
- **Storage:** Signatures are written as OCI 1.1 referrers
|
||||
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
|
||||
(`--new-bundle-format`). No `:sha256-<hex>.sig` tag clutter.
|
||||
(current cosign releases do this by default; no `--new-bundle-format`
|
||||
flag). No `:sha256-<hex>.sig` tag clutter.
|
||||
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
|
||||
referrers API, hands it to `sigstore-go`, and verifies it against the
|
||||
policy declared in the gallery YAML (`Gallery.Verification`).
|
||||
@@ -33,15 +34,14 @@ to sign. The job needs:
|
||||
|
||||
- `permissions: { id-token: write, contents: read }` at the job level so
|
||||
the runner can exchange its GitHub OIDC token for a Fulcio cert.
|
||||
- `sigstore/cosign-installer@v3` step (cosign ≥ 2.2 for
|
||||
`--new-bundle-format`).
|
||||
- `sigstore/cosign-installer@v3` step (current cosign releases already
|
||||
default to the new bundle format).
|
||||
- After each `docker buildx imagetools create`, resolve the resulting
|
||||
list digest with `docker buildx imagetools inspect <tag> --format
|
||||
'{{.Manifest.Digest}}'` and sign:
|
||||
|
||||
```sh
|
||||
cosign sign --yes --recursive \
|
||||
--new-bundle-format \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"${REGISTRY_REPO}@${DIGEST}"
|
||||
```
|
||||
@@ -49,6 +49,12 @@ cosign sign --yes --recursive \
|
||||
Sign by digest, never by tag — signing by tag binds the signature to
|
||||
whatever the tag points at *now*, and a subsequent tag push orphans it.
|
||||
|
||||
`--registry-referrers-mode=oci-1-1` is still gated behind
|
||||
`COSIGN_EXPERIMENTAL=1` in cosign v2.4.x (set at the job env level in
|
||||
`backend_merge.yml`). Re-evaluate when bumping the pinned cosign release
|
||||
— newer versions are expected to graduate this flag and the env var can
|
||||
then be dropped.
|
||||
|
||||
`backend_build_darwin.yml` builds and pushes single-arch darwin images
|
||||
that bypass the manifest-list merge. If/when those entries get a gallery
|
||||
`verification:` policy, the equivalent cosign step has to land there
|
||||
|
||||
10
.github/workflows/backend_merge.yml
vendored
10
.github/workflows/backend_merge.yml
vendored
@@ -40,6 +40,11 @@ jobs:
|
||||
id-token: write
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
# cosign v2.4.x still gates --registry-referrers-mode=oci-1-1 behind
|
||||
# this flag. Without it, signing fails with:
|
||||
# invalid argument "oci-1-1" for "--registry-referrers-mode" flag:
|
||||
# in order to use mode "oci-1-1", you must set COSIGN_EXPERIMENTAL=1
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
steps:
|
||||
# Sparse checkout: the merge job needs `.github/scripts/` (for the
|
||||
# keepalive cleanup script) but none of the source tree.
|
||||
@@ -66,7 +71,8 @@ jobs:
|
||||
|
||||
# cosign signs each pushed manifest list with --recursive so the
|
||||
# index and every per-arch entry get an attached Sigstore bundle.
|
||||
# 2.2+ is required for --new-bundle-format.
|
||||
# Recent cosign releases always emit the new bundle format, so
|
||||
# there's no extra CLI flag to opt into it.
|
||||
- name: Install cosign
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: sigstore/cosign-installer@v3
|
||||
@@ -153,7 +159,6 @@ jobs:
|
||||
# manifest before checking signatures need the per-arch
|
||||
# signatures, not just the list-level one.
|
||||
cosign sign --yes --recursive \
|
||||
--new-bundle-format \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"quay.io/go-skynet/local-ai-backends@${digest}"
|
||||
|
||||
@@ -180,7 +185,6 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
cosign sign --yes --recursive \
|
||||
--new-bundle-format \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"localai/localai-backends@${digest}"
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -77,3 +77,6 @@ local-backends/
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
# Local worktrees
|
||||
.worktrees/
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
|
||||
# Upstream pin lives below as DS4_VERSION?=444afce822057d87f14c4dec307dce24fd49b3ee
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
|
||||
DS4_VERSION?=444afce822057d87f14c4dec307dce24fd49b3ee
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=48a55f74e4c6e2aeda363dd386c1ac9170a0af71
|
||||
IK_LLAMA_VERSION?=642c038ccdf3dd08e6d9ac6fdc3b1c311ebd8a02
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=ad277572619fcfb6ddd38f4c6437283a4b2b8636
|
||||
LLAMA_VERSION?=c0c7e147e7efa6c5858754b47259ba4880f8a906
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=5b0267e941cade15bd80089d89838795d9f4baa6
|
||||
STABLEDIFFUSION_GGML_VERSION?=a397e03488cc27e1a42da646b82dfce9f50741c0
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=8443cf05e3fa8ce1b32348e1bcbcf8fc31f7f3ae
|
||||
WHISPER_CPP_VERSION?=0ccd896f5b882628e1c077f9769735ef4ce52860
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -36,15 +36,11 @@ fi
|
||||
# flash-attn-4 4.0 stable lands.
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels are built for cp312 and shipped via
|
||||
# pypi.jetson-ai-lab.io. Bump the venv Python so the prebuilt sglang
|
||||
# wheel resolves cleanly. The actual install on l4t13 goes through
|
||||
# pyproject.toml (see the elif branch below) so [tool.uv.sources] can
|
||||
# pin only torch/torchvision/torchaudio/sglang to the jetson-ai-lab
|
||||
# index — leaving PyPI as the path for transitive deps like
|
||||
# markdown-it-py / anthropic / propcache that the L4T mirror's proxy
|
||||
# 503s on. No --index-strategy flag here: the explicit index keeps the
|
||||
# scoping clean.
|
||||
# JetPack 7 / L4T arm64 sglang + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and sglang 0.5.11+
|
||||
# ships a cp312 aarch64 wheel pinned to that torch). They're cp312-only,
|
||||
# so bump the venv Python accordingly.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
@@ -110,27 +106,6 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
fi
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/torchvision/torchaudio/sglang to the
|
||||
# jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers / accelerate) comes from
|
||||
# PyPI. Bypasses installRequirements because uv pip install -r
|
||||
# requirements.txt does not honor sources — see
|
||||
# backend/python/sglang/pyproject.toml for the rationale. Mirrors the
|
||||
# equivalent path in backend/python/vllm/install.sh.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — sglang/sgl-kernel sdists need packaging/setuptools-scm in the
|
||||
# venv before they can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
else
|
||||
installRequirements
|
||||
fi
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the sglang backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / sglang / sgl-kernel
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently.
|
||||
# With `--extra-index-url` + `--index-strategy=unsafe-best-match` (the
|
||||
# historical fix in install.sh) uv would pick those proxy URLs for ordinary
|
||||
# PyPI packages — markdown-it-py, anthropic, propcache, etc. — and trip on
|
||||
# the 503s. See e.g. CI run 25439791228 (markdown-it-py-4.0.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there. Mirrors the equivalent fix already
|
||||
# in backend/python/vllm/pyproject.toml.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-sglang-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
# sglang on jetson — the [all] extra is deliberately omitted because it
|
||||
# pulls outlines/decord, and decord has no aarch64 cp312 wheel anywhere
|
||||
# (PyPI nor the jetson-ai-lab index ships only legacy cp35-cp37). With
|
||||
# [all] uv backtracks through versions trying to satisfy decord and
|
||||
# lands on sglang==0.1.16. The 0.5.0 floor matches the only major
|
||||
# series the jetson-ai-lab sbsa/cu130 mirror currently publishes
|
||||
# (sglang==0.5.1.post2 as of 2026-05-06). Bumping to >=0.5.11 here
|
||||
# would make the build unsatisfiable until the mirror catches up.
|
||||
# Gemma 4 / MTP recipes are therefore not supported on l4t13 — those
|
||||
# features land on cublas12/cublas13 hosts that pull the newer wheel
|
||||
# from PyPI. backend.py keeps backward compat with the 0.5.x SamplingParams
|
||||
# field rename via runtime detection.
|
||||
"sglang>=0.5.0",
|
||||
# PyPI-resolvable packages that complete the runtime.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
sglang = { index = "jetson-ai-lab" }
|
||||
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# sglang 0.5.11+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist
|
||||
# pins torch==2.11.0 / torchaudio==2.11.0, locking an ABI-consistent set with
|
||||
# the cu130 torch wheel installed above. 0.5.11 is the floor for Gemma 4
|
||||
# support (sgl-project/sglang#21952).
|
||||
#
|
||||
# The [all] extra is deliberately NOT used on aarch64: it pulls the
|
||||
# [diffusion] sub-extra which requires `xatlas`, and xatlas ships no
|
||||
# aarch64 wheel and its sdist depends on scikit_build_core without
|
||||
# declaring it in build-system.requires — so under --no-build-isolation
|
||||
# uv can't build it. Upstream sglang gates st_attn and vsa on
|
||||
# platform_machine != aarch64 in the diffusion extra but forgot xatlas.
|
||||
# Plain `sglang` carries everything backend.py uses (Engine, ServerArgs,
|
||||
# FunctionCallParser, ReasoningParser); the [all] extras are optional
|
||||
# accelerators not required at import time.
|
||||
sglang>=0.5.11
|
||||
9
backend/python/sglang/requirements-l4t13.txt
Normal file
9
backend/python/sglang/requirements-l4t13.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
@@ -13,14 +13,14 @@ else
|
||||
fi
|
||||
|
||||
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
|
||||
# unsafe-best-match is required on l4t13 because the jetson-ai-lab index
|
||||
# lists transitive deps at limited versions — without it uv pins to the
|
||||
# first matching index and fails to resolve a compatible wheel from PyPI.
|
||||
# Since PyTorch 2.11 (April 2026) PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# directly for torch/torchvision/torchaudio and an aarch64 vllm wheel pinned
|
||||
# to that torch, so the jetson-ai-lab mirror is no longer needed.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS="${EXTRA_PIP_INSTALL_FLAGS:-} --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
@@ -42,18 +42,11 @@ if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
else
|
||||
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# JetPack 7 / L4T arm64 cu130 — vllm comes from the prebuilt SBSA wheel
|
||||
# at jetson-ai-lab. Version is unpinned: the index ships whatever build
|
||||
# matches the cu130/cp312 ABI. unsafe-best-match lets uv fall through
|
||||
# to PyPI for transitive deps not present on the jetson-ai-lab index.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
else
|
||||
uv pip install --index-strategy=unsafe-best-match vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
# vllm 0.19+ defaults to cu130 wheels on PyPI, no extra index needed.
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# cublas13 (x86_64) and l4t13 (aarch64) both pull vllm from PyPI now:
|
||||
# vllm 0.19+ defaults to cu130 wheels on x86_64 and vllm 0.20+ ships an
|
||||
# aarch64 manylinux wheel pinned to torch==2.11.0. No extra index needed
|
||||
# in either case.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --torch-backend=auto
|
||||
else
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# for torch/torchvision/torchaudio directly since PyTorch 2.11 (April 2026),
|
||||
# so no custom index is needed. flash-attn is dropped here: PyPI has no
|
||||
# aarch64 wheel for it, but vLLM 0.20+ bundles its own vllm_flash_attn
|
||||
# (fa2 + fa3) inside the main wheel, so it is not required at runtime.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
bitsandbytes
|
||||
flash-attn
|
||||
diffusers
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
@@ -43,14 +43,11 @@ if [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
|
||||
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
|
||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
#
|
||||
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
|
||||
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
|
||||
# That keeps PyPI as the resolution path for transitive deps like
|
||||
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
|
||||
# JetPack 7 / L4T arm64 vllm + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and vllm 0.20+ ships
|
||||
# an aarch64 wheel pinned to that torch). They're cp312-only, so bump the
|
||||
# venv Python accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
@@ -103,25 +100,6 @@ if [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
|
||||
# to the jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
|
||||
# installRequirements because uv pip install -r requirements.txt does not
|
||||
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — fastsafetensors and friends need pybind11 in the venv before
|
||||
# their sdists can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
|
||||
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
|
||||
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
|
||||
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
|
||||
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-vllm-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
"charset-normalizer>=3.4.7",
|
||||
"chardet",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"flash-attn",
|
||||
"vllm",
|
||||
# PyPI-resolvable packages that complete the runtime — accelerate,
|
||||
# transformers, bitsandbytes carry their own wheels for aarch64.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
"bitsandbytes",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
flash-attn = { index = "jetson-ai-lab" }
|
||||
vllm = { index = "jetson-ai-lab" }
|
||||
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# vLLM 0.20+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist pins
|
||||
# torch==2.11.0 / torchvision==0.26.0 / torchaudio==2.11.0, locking an ABI-
|
||||
# consistent set with the cu130 torch wheel installed above.
|
||||
vllm
|
||||
8
backend/python/vllm/requirements-l4t13.txt
Normal file
8
backend/python/vllm/requirements-l4t13.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
transformers
|
||||
bitsandbytes
|
||||
8
backend/rust/kokoros/Cargo.lock
generated
8
backend/rust/kokoros/Cargo.lock
generated
@@ -1392,9 +1392,9 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.80"
|
||||
version = "0.10.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a45fa2aa886c42762255da344f0a0d313e254066c46aad76f300c3d3da62d967"
|
||||
checksum = "bf0b434746ee2832f4f0baf10137e1cabb18cbe6912c69e2e33263c45250f542"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cfg-if",
|
||||
@@ -1423,9 +1423,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.116"
|
||||
version = "0.9.115"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4"
|
||||
checksum = "158fe5b292746440aa6e7a7e690e55aeb72d41505e2804c23c6973ad0e9c9781"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
|
||||
@@ -233,7 +233,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
|
||||
registry,
|
||||
natsClient,
|
||||
cfg.Distributed.BackendInstallTimeoutOrDefault(),
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
@@ -200,7 +200,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -552,6 +552,13 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
// Allow the on-disk setting to override the CLI/env default. The
|
||||
// startup default is non-zero (see NewApplicationConfig), so a plain
|
||||
// `== 0` guard like the others would never trigger; we instead respect
|
||||
// any value the file specifies. 0 in the file means "uncapped".
|
||||
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
|
||||
// Branding / whitelabeling. There are no env vars for these — the file is
|
||||
// the only source — so apply unconditionally. Without this block a server
|
||||
|
||||
@@ -78,7 +78,7 @@ func ModelAudioTransform(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func ModelAudioTransform(
|
||||
data["sample_rate"] = res.SampleRate
|
||||
data["samples"] = res.Samples
|
||||
data["reference_provided"] = res.ReferenceProvided
|
||||
if snippet := trace.AudioSnippet(dst); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(dst, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func Detection(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"positive_prompt": positive_prompt,
|
||||
|
||||
@@ -305,7 +305,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
|
||||
if o.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems, o.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
@@ -316,9 +316,13 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
// Cap the captured fields up front: agent-pool LLM calls embed the
|
||||
// full augmented chat history in messages and the full reply in
|
||||
// response, so without a per-field cap a single trace can dwarf the
|
||||
// rest of the buffer. The cap matches the API-trace body cap.
|
||||
if len(messages) > 0 {
|
||||
if msgJSON, err := json.Marshal(messages); err == nil {
|
||||
traceData["messages"] = string(msgJSON)
|
||||
traceData["messages"] = trace.TruncateToBytes(string(msgJSON), o.TracingMaxBodyBytes)
|
||||
}
|
||||
}
|
||||
if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil {
|
||||
@@ -337,7 +341,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
resp, err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
traceData["response"] = resp.Response
|
||||
traceData["response"] = trace.TruncateToBytes(resp.Response, o.TracingMaxBodyBytes)
|
||||
traceData["token_usage"] = map[string]any{
|
||||
"prompt": resp.Usage.Prompt,
|
||||
"completion": resp.Usage.Completion,
|
||||
@@ -359,10 +363,10 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
toolCallCount += len(d.ToolCalls)
|
||||
}
|
||||
if len(contentParts) > 0 {
|
||||
chatDeltasInfo["content"] = strings.Join(contentParts, "")
|
||||
chatDeltasInfo["content"] = trace.TruncateToBytes(strings.Join(contentParts, ""), o.TracingMaxBodyBytes)
|
||||
}
|
||||
if len(reasoningParts) > 0 {
|
||||
chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||
chatDeltasInfo["reasoning_content"] = trace.TruncateToBytes(strings.Join(reasoningParts, ""), o.TracingMaxBodyBytes)
|
||||
}
|
||||
if toolCallCount > 0 {
|
||||
chatDeltasInfo["tool_call_count"] = toolCallCount
|
||||
|
||||
@@ -21,7 +21,7 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
if !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
@@ -277,7 +277,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
MinP: float32(*c.MinP),
|
||||
Tokens: int32(*c.Maxtokens),
|
||||
Threads: int32(*c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheAll: *c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
F16KV: *c.F16,
|
||||
|
||||
@@ -25,7 +25,7 @@ func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.Mod
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ func SoundGeneration(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -76,10 +76,10 @@ func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest
|
||||
var startTime time.Time
|
||||
var audioSnippet map[string]any
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
// Capture audio before the backend call — the backend may delete the file.
|
||||
audioSnippet = trace.AudioSnippet(req.Audio)
|
||||
audioSnippet = trace.AudioSnippet(req.Audio, appConfig.TracingMaxBodyBytes)
|
||||
}
|
||||
|
||||
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
||||
|
||||
@@ -67,7 +67,7 @@ func ModelTTS(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func ModelTTS(
|
||||
"language": language,
|
||||
}
|
||||
if err == nil && res.Success {
|
||||
if snippet := trace.AudioSnippet(filePath); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(filePath, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
@@ -161,7 +161,7 @@ func ModelTTSStream(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -260,7 +260,7 @@ func ModelTTSStream(
|
||||
"streaming": true,
|
||||
}
|
||||
if resultErr == nil && len(snippetPCM) > 0 {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"prompt": prompt,
|
||||
|
||||
@@ -31,7 +31,7 @@ func VoiceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func VoiceEmbed(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func VoiceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -39,19 +39,19 @@ type RunCMD struct {
|
||||
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
||||
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
||||
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||
PreferDevelopmentBackends bool `env:"LOCALAI_PREFER_DEV_BACKENDS,PREFER_DEV_BACKENDS" help:"Prefer development backend versions (shows development backends by default in UI)" group:"backends" default:"false"`
|
||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
|
||||
F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
|
||||
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
||||
@@ -100,6 +100,7 @@ type RunCMD struct {
|
||||
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||
TracingMaxBodyBytes int `env:"LOCALAI_TRACING_MAX_BODY_BYTES" default:"65536" help:"Maximum bytes captured per request/response body in the trace buffer (0 = uncapped). Caps memory growth from chatty endpoints like /embeddings." group:"api"`
|
||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
@@ -144,16 +145,19 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
@@ -254,12 +258,29 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.StorageSecretKey != "" {
|
||||
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||
}
|
||||
if r.BackendInstallTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendInstallTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT %q: %w", r.BackendInstallTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendInstallTimeout(d))
|
||||
}
|
||||
if r.BackendUpgradeTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendUpgradeTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT %q: %w", r.BackendUpgradeTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendUpgradeTimeout(d))
|
||||
}
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
@@ -273,6 +294,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.EnableTracing)
|
||||
}
|
||||
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
|
||||
opts = append(opts, config.WithTracingMaxBodyBytes(r.TracingMaxBodyBytes))
|
||||
|
||||
token := ""
|
||||
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
||||
|
||||
@@ -21,6 +21,7 @@ type ApplicationConfig struct {
|
||||
Debug bool
|
||||
EnableTracing bool
|
||||
TracingMaxItems int
|
||||
TracingMaxBodyBytes int // Per-body cap for captured request/response bodies; 0 disables the cap
|
||||
EnableBackendLogging bool
|
||||
GeneratedContentDir string
|
||||
|
||||
@@ -111,6 +112,18 @@ type ApplicationConfig struct {
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed DistributedConfig
|
||||
|
||||
// ExposeNodeHeader, when true, activates middleware.ExposeNodeHeader on
|
||||
// the inference routes (OpenAI chat/completions/embeddings, Anthropic
|
||||
// /v1/messages, Ollama /api/chat,/api/generate,/api/embed). The
|
||||
// middleware wraps the response writer and attaches an "X-LocalAI-Node"
|
||||
// response header carrying the ID of the distributed-mode worker node
|
||||
// that served the request. Off by default because the node ID is
|
||||
// internal topology that can aid attacker reconnaissance if surfaced on
|
||||
// a public endpoint; operators opt in explicitly via
|
||||
// --expose-node-header / LOCALAI_EXPOSE_NODE_HEADER for debugging,
|
||||
// observability and load-balancer attribution.
|
||||
ExposeNodeHeader bool
|
||||
|
||||
// LocalAI Assistant chat modality. Hard-disable the in-process admin MCP
|
||||
// server with this flag; runtime-toggleable via /api/settings.
|
||||
DisableLocalAIAssistant bool
|
||||
@@ -187,6 +200,7 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
@@ -578,6 +592,12 @@ func WithTracingMaxItems(items int) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithTracingMaxBodyBytes(bytes int) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.TracingMaxBodyBytes = bytes
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeneratedContentDir(generatedContentDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.GeneratedContentDir = generatedContentDir
|
||||
@@ -885,6 +905,15 @@ func WithDisableLocalAIAssistant(disabled bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeNodeHeader enables the X-LocalAI-Node response header on
|
||||
// inference endpoints. Default off; the node ID reveals internal cluster
|
||||
// topology and is opt-in for that reason.
|
||||
func WithExposeNodeHeader(enabled bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ExposeNodeHeader = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||
// Some options defined at the application level are going to be passed as defaults for
|
||||
// all the configuration for the models.
|
||||
@@ -920,6 +949,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
f16 := o.F16
|
||||
debug := o.Debug
|
||||
tracingMaxItems := o.TracingMaxItems
|
||||
tracingMaxBodyBytes := o.TracingMaxBodyBytes
|
||||
enableTracing := o.EnableTracing
|
||||
enableBackendLogging := o.EnableBackendLogging
|
||||
cors := o.CORS
|
||||
@@ -1008,6 +1038,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
F16: &f16,
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
TracingMaxBodyBytes: &tracingMaxBodyBytes,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
@@ -1146,6 +1177,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.TracingMaxItems != nil {
|
||||
o.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
o.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
if settings.EnableBackendLogging != nil {
|
||||
o.EnableBackendLogging = *settings.EnableBackendLogging
|
||||
}
|
||||
|
||||
@@ -40,7 +40,10 @@ type DistributedConfig struct {
|
||||
// model-row cleanup on MarkUnhealthy / MarkDraining).
|
||||
DisablePerModelHealthCheck bool
|
||||
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
|
||||
BackendInstallTimeout time.Duration // NATS round-trip timeout for backend.install (default 15m)
|
||||
BackendUpgradeTimeout time.Duration // NATS round-trip timeout for backend.upgrade (default 15m)
|
||||
|
||||
MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
|
||||
|
||||
@@ -68,13 +71,15 @@ func (c DistributedConfig) Validate() error {
|
||||
}
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||
"drain-timeout": c.DrainTimeout,
|
||||
"health-check-interval": c.HealthCheckInterval,
|
||||
"stale-node-threshold": c.StaleNodeThreshold,
|
||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
FlagMCPDiscoveryTimeout: c.MCPDiscoveryTimeout,
|
||||
FlagWorkerWaitTimeout: c.WorkerWaitTimeout,
|
||||
FlagDrainTimeout: c.DrainTimeout,
|
||||
FlagHealthCheckInterval: c.HealthCheckInterval,
|
||||
FlagStaleNodeThreshold: c.StaleNodeThreshold,
|
||||
FlagMCPCIJobTimeout: c.MCPCIJobTimeout,
|
||||
FlagBackendInstallTimeout: c.BackendInstallTimeout,
|
||||
FlagBackendUpgradeTimeout: c.BackendUpgradeTimeout,
|
||||
} {
|
||||
if d < 0 {
|
||||
return fmt.Errorf("%s must not be negative", name)
|
||||
@@ -137,24 +142,66 @@ func WithStorageSecretKey(key string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendInstallTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendInstallTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendUpgradeTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendUpgradeTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
// facing surface that needs to reference a specific knob by name. Keeping
|
||||
// them as constants prevents the string from drifting from the actual
|
||||
// flag a future rename would produce.
|
||||
const (
|
||||
FlagMCPToolTimeout = "mcp-tool-timeout"
|
||||
FlagMCPDiscoveryTimeout = "mcp-discovery-timeout"
|
||||
FlagWorkerWaitTimeout = "worker-wait-timeout"
|
||||
FlagDrainTimeout = "drain-timeout"
|
||||
FlagHealthCheckInterval = "health-check-interval"
|
||||
FlagStaleNodeThreshold = "stale-node-threshold"
|
||||
FlagMCPCIJobTimeout = "mcp-ci-job-timeout"
|
||||
FlagBackendInstallTimeout = "backend-install-timeout"
|
||||
FlagBackendUpgradeTimeout = "backend-upgrade-timeout"
|
||||
)
|
||||
|
||||
// Defaults for distributed timeouts.
|
||||
const (
|
||||
DefaultMCPToolTimeout = 360 * time.Second
|
||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
DefaultHealthCheckInterval = 15 * time.Second
|
||||
DefaultStaleNodeThreshold = 60 * time.Second
|
||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||
DefaultMCPToolTimeout = 360 * time.Second
|
||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
DefaultHealthCheckInterval = 15 * time.Second
|
||||
DefaultStaleNodeThreshold = 60 * time.Second
|
||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||
DefaultBackendInstallTimeout = 15 * time.Minute
|
||||
DefaultBackendUpgradeTimeout = 15 * time.Minute
|
||||
)
|
||||
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
}
|
||||
|
||||
// BackendUpgradeTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendUpgradeTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendUpgradeTimeout, DefaultBackendUpgradeTimeout)
|
||||
}
|
||||
|
||||
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
||||
|
||||
90
core/config/distributed_config_test.go
Normal file
90
core/config/distributed_config_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
var _ = Describe("DistributedConfig backend NATS timeouts", func() {
|
||||
Context("BackendInstallTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendInstallTimeout: 42 * time.Minute}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(42 * time.Minute))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BackendUpgradeTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendUpgradeTimeout: 30 * time.Minute}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(30 * time.Minute))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig flag-name constants", func() {
|
||||
// Pin the kebab-case strings so a rename of the Go field name (or a
|
||||
// CLI flag naming convention change) forces the constant to update,
|
||||
// keeping the Validate error messages and any future operator-facing
|
||||
// surface in sync with the actual CLI flag.
|
||||
DescribeTable("flag name constants",
|
||||
func(actual, expected string) {
|
||||
Expect(actual).To(Equal(expected))
|
||||
},
|
||||
Entry("MCP tool timeout", config.FlagMCPToolTimeout, "mcp-tool-timeout"),
|
||||
Entry("MCP discovery timeout", config.FlagMCPDiscoveryTimeout, "mcp-discovery-timeout"),
|
||||
Entry("worker wait timeout", config.FlagWorkerWaitTimeout, "worker-wait-timeout"),
|
||||
Entry("drain timeout", config.FlagDrainTimeout, "drain-timeout"),
|
||||
Entry("health check interval", config.FlagHealthCheckInterval, "health-check-interval"),
|
||||
Entry("stale node threshold", config.FlagStaleNodeThreshold, "stale-node-threshold"),
|
||||
Entry("MCP CI job timeout", config.FlagMCPCIJobTimeout, "mcp-ci-job-timeout"),
|
||||
Entry("backend install timeout", config.FlagBackendInstallTimeout, "backend-install-timeout"),
|
||||
Entry("backend upgrade timeout", config.FlagBackendUpgradeTimeout, "backend-upgrade-timeout"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||
It("rejects a negative BackendInstallTimeout with the flag name in the error", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
BackendInstallTimeout: -1 * time.Second,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(config.FlagBackendInstallTimeout))
|
||||
Expect(err.Error()).To(ContainSubstring("must not be negative"))
|
||||
})
|
||||
|
||||
It("rejects a negative BackendUpgradeTimeout with the flag name in the error", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
BackendUpgradeTimeout: -1 * time.Second,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(config.FlagBackendUpgradeTimeout))
|
||||
})
|
||||
|
||||
It("accepts all-zero durations as valid (defaults apply)", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
@@ -136,4 +136,36 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
|
||||
})
|
||||
})
|
||||
|
||||
Context("PromptCacheAll default", func() {
|
||||
It("defaults to true when omitted from YAML", func() {
|
||||
cfg := &ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
|
||||
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||
Expect(*cfg.PromptCacheAll).To(BeTrue())
|
||||
})
|
||||
|
||||
It("preserves an explicit false from YAML", func() {
|
||||
falseV := false
|
||||
cfg := &ModelConfig{
|
||||
LLMConfig: LLMConfig{PromptCacheAll: &falseV},
|
||||
}
|
||||
cfg.SetDefaults()
|
||||
|
||||
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||
Expect(*cfg.PromptCacheAll).To(BeFalse())
|
||||
})
|
||||
|
||||
It("preserves an explicit true from YAML", func() {
|
||||
trueV := true
|
||||
cfg := &ModelConfig{
|
||||
LLMConfig: LLMConfig{PromptCacheAll: &trueV},
|
||||
}
|
||||
cfg.SetDefaults()
|
||||
|
||||
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||
Expect(*cfg.PromptCacheAll).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -209,7 +209,7 @@ type LLMConfig struct {
|
||||
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
|
||||
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
|
||||
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
|
||||
PromptCacheAll *bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
|
||||
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
|
||||
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
|
||||
@@ -494,6 +494,13 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
cfg.Reranking = &falseV
|
||||
}
|
||||
|
||||
if cfg.PromptCacheAll == nil {
|
||||
// Match upstream llama.cpp's default (common/common.h: cache_prompt = true)
|
||||
// and let cache_idle_slots / kv_unified actually do useful work; users can
|
||||
// opt out with an explicit `prompt_cache_all: false` in the model YAML.
|
||||
cfg.PromptCacheAll = &trueV
|
||||
}
|
||||
|
||||
if threads == 0 {
|
||||
// Threads can't be 0
|
||||
threads = 4
|
||||
|
||||
@@ -38,6 +38,7 @@ type RuntimeSettings struct {
|
||||
Debug *bool `json:"debug,omitempty"`
|
||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||
TracingMaxBodyBytes *int `json:"tracing_max_body_bytes,omitempty"` // Per-body cap in bytes; 0 disables the cap
|
||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||
|
||||
// Security/CORS settings
|
||||
|
||||
@@ -73,363 +73,6 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int) error {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
// Always keep the Go-side extractor in sync with raw tokens so it
|
||||
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
// When C++ autoparser chat deltas are available, prefer them — they
|
||||
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
||||
// Otherwise fall back to Go-side extraction.
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
}
|
||||
if reasoningDelta != "" {
|
||||
delta.Reasoning = &reasoningDelta
|
||||
}
|
||||
|
||||
// Usage rides as a struct field for the consumer to track the
|
||||
// running cumulative — it is stripped before JSON marshal so the
|
||||
// wire chunk stays spec-compliant (no `usage` on intermediate
|
||||
// chunks). The dedicated trailer chunk (when include_usage=true)
|
||||
// carries the final totals.
|
||||
usageForChunk := usage
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: &usageForChunk,
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
return err
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int, textContentToReturn *string) error {
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
sentReasoning := false
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
_, _, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
// Track whether ChatDeltas from the C++ autoparser contain
|
||||
// tool calls or content, so the retry decision can account for them.
|
||||
for _, d := range usage.ChatDeltas {
|
||||
if len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaToolCalls = true
|
||||
}
|
||||
if d.Content != "" {
|
||||
hasChatDeltaContent = true
|
||||
}
|
||||
}
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
if usage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
if reasoningDelta != "" {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentReasoning = true
|
||||
}
|
||||
|
||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||
// have been detected. Once the incremental parser finds tool calls,
|
||||
// content stops — per OpenAI spec, content and tool_calls don't mix.
|
||||
if lastEmittedCount == 0 && contentDelta != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &contentDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
// Try incremental XML parsing for streaming support using iterative parser
|
||||
// This allows emitting partial tool calls as they're being generated
|
||||
cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
|
||||
// Determine XML format from config
|
||||
var xmlFormat *functions.XMLToolCallFormat
|
||||
if config.FunctionsConfig.XMLFormat != nil {
|
||||
xmlFormat = config.FunctionsConfig.XMLFormat
|
||||
} else if config.FunctionsConfig.XMLFormatPreset != "" {
|
||||
xmlFormat = functions.GetXMLFormatPreset(config.FunctionsConfig.XMLFormatPreset)
|
||||
}
|
||||
|
||||
// Use iterative parser for streaming (partial parsing enabled)
|
||||
// Try XML parsing first
|
||||
partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
|
||||
if parseErr == nil && len(partialResults) > 0 {
|
||||
// Emit new XML tool calls that weren't emitted before
|
||||
if len(partialResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(partialResults); i++ {
|
||||
toolCall := partialResults[i]
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: toolCall.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
select {
|
||||
case responses <- initialMessage:
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastEmittedCount = len(partialResults)
|
||||
}
|
||||
} else {
|
||||
// Try JSON tool call parsing for streaming.
|
||||
// Only emit NEW tool calls (same guard as XML parser above).
|
||||
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
||||
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
||||
jsonObj := jsonResults[i]
|
||||
name, ok := jsonObj["name"].(string)
|
||||
if !ok || name == "" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||
if argsStr, ok := argsVal.(string); ok {
|
||||
args = argsStr
|
||||
} else {
|
||||
argsBytes, _ := json.Marshal(argsVal)
|
||||
args = string(argsBytes)
|
||||
}
|
||||
}
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
}
|
||||
lastEmittedCount = len(jsonResults)
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
func(attempt int) bool {
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry).
|
||||
// Also check ChatDelta flags — when the C++ autoparser is active,
|
||||
// tool calls and content are delivered via ChatDeltas while the
|
||||
// raw message is cleared. Without this check, we'd retry
|
||||
// unnecessarily, losing valid results and concatenating output.
|
||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||
hasContent := cleaned != "" || hasChatDeltaContent
|
||||
if !hasContent && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
hasChatDeltaToolCalls = false
|
||||
hasChatDeltaContent = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
||||
var functionResults []functions.FuncCallResults
|
||||
var reasoning string
|
||||
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
||||
functionResults = deltaToolCalls
|
||||
// Use content/reasoning from deltas too
|
||||
*textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning = extractor.Reasoning()
|
||||
cleanedResult := extractor.CleanedContent()
|
||||
*textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn)
|
||||
// noAction is a sentinel "just answer" pseudo-function — not a real
|
||||
// tool call. Scan the whole slice rather than only index 0 so we
|
||||
// don't drop a real tool call that happens to follow a noAction
|
||||
// entry, and so the default branch isn't entered with only noAction
|
||||
// entries to emit as tool_calls.
|
||||
noActionToRun := !hasRealCall(functionResults, noAction)
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
// Token-cumulative usage is communicated to the streaming
|
||||
// consumer via the per-token callback's chunk struct (stripped
|
||||
// before wire marshal). The final usage trailer — when the
|
||||
// caller opted in with stream_options.include_usage — is built
|
||||
// by the outer streaming loop, not here.
|
||||
var result string
|
||||
if !sentInitialRole {
|
||||
var hqErr error
|
||||
result, hqErr = handleQuestion(config, functionResults, extractor.CleanedContent(), prompt)
|
||||
if hqErr != nil {
|
||||
xlog.Error("error handling question", "error", hqErr)
|
||||
return hqErr
|
||||
}
|
||||
}
|
||||
for _, chunk := range buildNoActionFinalChunks(
|
||||
id, req.Model, created,
|
||||
sentInitialRole, sentReasoning,
|
||||
result, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
|
||||
default:
|
||||
for _, chunk := range buildDeferredToolCallChunks(
|
||||
id, req.Model, created,
|
||||
functionResults, lastEmittedCount,
|
||||
sentInitialRole, *textContentToReturn,
|
||||
sentReasoning, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
}
|
||||
|
||||
close(responses)
|
||||
return err
|
||||
}
|
||||
|
||||
return func(c echo.Context) error {
|
||||
var textContentToReturn string
|
||||
id := uuid.New().String()
|
||||
@@ -682,6 +325,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer: the first c.Response().Write / Flush lazily reads the
|
||||
// node ID from the loader (post-ml.Load) and stamps the header
|
||||
// before the byte hits the underlying writer. No per-request
|
||||
// chan / per-handler plumbing needed here.
|
||||
|
||||
mcpStreamMaxIterations := 10
|
||||
if config.Agent.MaxIterations > 0 {
|
||||
@@ -697,17 +346,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
ended := make(chan error, 1)
|
||||
ended := make(chan streamWorkerResult, 1)
|
||||
|
||||
go func() {
|
||||
if !shouldUseFn {
|
||||
ended <- process(predInput, input, config, ml, responses, extraUsage, id, created)
|
||||
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created)
|
||||
ended <- streamWorkerResult{usage: u, err: err}
|
||||
} else {
|
||||
ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage, id, created, &textContentToReturn)
|
||||
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn)
|
||||
ended <- streamWorkerResult{usage: u, err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
usage := &schema.OpenAIUsage{}
|
||||
var finalUsage backend.TokenUsage
|
||||
toolsCalled := false
|
||||
var collectedToolCalls []schema.ToolCall
|
||||
var collectedContent string
|
||||
@@ -725,13 +376,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
xlog.Debug("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
// Capture the running cumulative usage from this chunk
|
||||
// (when present) so the include_usage trailer can carry
|
||||
// the final totals. Usage is stripped before marshal
|
||||
// below so the wire chunk stays spec-compliant.
|
||||
if ev.Usage != nil {
|
||||
usage = ev.Usage
|
||||
}
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
// Collect and merge tool call deltas for MCP execution
|
||||
@@ -747,11 +391,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
collectedContent += *sp
|
||||
}
|
||||
}
|
||||
// OpenAI streaming spec: intermediate chunks must NOT
|
||||
// carry a `usage` field. Strip the tracking copy
|
||||
// before marshalling — usage is delivered via the
|
||||
// dedicated trailer chunk when include_usage=true.
|
||||
ev.Usage = nil
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
xlog.Debug("Failed to marshal response", "error", err)
|
||||
@@ -766,15 +405,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
case res := <-ended:
|
||||
if res.err == nil {
|
||||
finalUsage = res.usage
|
||||
break LOOP
|
||||
}
|
||||
xlog.Error("Stream ended with error", "error", err)
|
||||
xlog.Error("Stream ended with error", "error", res.err)
|
||||
|
||||
errorResp := schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: err.Error(),
|
||||
Message: res.err.Error(),
|
||||
Type: "server_error",
|
||||
Code: "server_error",
|
||||
},
|
||||
@@ -797,7 +437,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// still trying to send (e.g., after client disconnect). The goroutine
|
||||
// calls close(responses) when done, which terminates the drain.
|
||||
if input.Context.Err() != nil {
|
||||
go func() { for range responses {} }()
|
||||
go func() {
|
||||
for range responses {
|
||||
}
|
||||
}()
|
||||
<-ended
|
||||
}
|
||||
|
||||
@@ -921,8 +564,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// Trailing usage chunk per OpenAI spec: emit only when the
|
||||
// caller opted in via stream_options.include_usage. Shape:
|
||||
// {"choices":[],"usage":{...},"object":"chat.completion.chunk",...}
|
||||
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage && usage != nil {
|
||||
trailer := streamUsageTrailerJSON(id, input.Model, created, *usage)
|
||||
//
|
||||
// finalUsage is the authoritative TokenUsage returned by the
|
||||
// worker function (process / processTools) via the `ended`
|
||||
// channel. The worker reads it from ComputeChoices' return
|
||||
// value, which is the cumulative count produced by the backend
|
||||
// over the whole prediction. Issue #9927 was caused by the
|
||||
// tools-path worker not surfacing this value at all.
|
||||
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage {
|
||||
trailerUsage := streamUsageFromTokenUsage(finalUsage, extraUsage)
|
||||
trailer := streamUsageTrailerJSON(id, input.Model, created, trailerUsage)
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||
}
|
||||
|
||||
@@ -932,7 +583,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return nil
|
||||
} // end MCP stream iteration loop
|
||||
|
||||
// Safety fallback
|
||||
// Safety fallback. The MCP iteration loop above always returns,
|
||||
// so this is structurally unreachable; if we ever reach it the
|
||||
// stream is closed cleanly. The middleware-installed wrapper
|
||||
// still stamps X-LocalAI-Node on this final write if applicable.
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
@@ -1290,6 +944,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
respData, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(respData))
|
||||
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer; c.JSON's writes trigger the lazy stamp.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
} // end MCP iteration loop
|
||||
|
||||
@@ -4,10 +4,39 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
)
|
||||
|
||||
// streamWorkerResult is what the streaming workers (process / processTools)
|
||||
// hand back to the outer ChatEndpoint loop through the `ended` channel.
|
||||
// Threading the final TokenUsage here, instead of piggy-backing it on the
|
||||
// `responses` SSE channel, keeps the SSE channel single-purpose (wire chunks)
|
||||
// and gives the trailer emitter a plain Go value to read after LOOP exits.
|
||||
// Fix for issue #9927: the previous tools-path worker never surfaced the
|
||||
// cumulative token counts at all, so the include_usage trailer reported zeros.
|
||||
type streamWorkerResult struct {
|
||||
usage backend.TokenUsage
|
||||
err error
|
||||
}
|
||||
|
||||
// streamUsageFromTokenUsage converts the backend's cumulative TokenUsage into
|
||||
// the OpenAI-spec OpenAIUsage shape used on the wire. `extraUsage` controls
|
||||
// whether the non-standard timing fields are forwarded.
|
||||
func streamUsageFromTokenUsage(usage backend.TokenUsage, extraUsage bool) schema.OpenAIUsage {
|
||||
out := schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
out.TimingTokenGeneration = usage.TimingTokenGeneration
|
||||
out.TimingPromptProcessing = usage.TimingPromptProcessing
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// streamUsageTrailerJSON returns the bytes of the OpenAI-spec trailing usage
|
||||
// chunk emitted in streaming completions when the request opts in via
|
||||
// `stream_options.include_usage: true`. The shape is:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -152,6 +156,28 @@ var _ = Describe("streaming usage spec compliance", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("streamUsageFromTokenUsage", func() {
|
||||
It("converts backend TokenUsage to schema OpenAIUsage", func() {
|
||||
tu := backend.TokenUsage{Prompt: 18, Completion: 213}
|
||||
u := streamUsageFromTokenUsage(tu, false)
|
||||
Expect(u.PromptTokens).To(Equal(18))
|
||||
Expect(u.CompletionTokens).To(Equal(213))
|
||||
Expect(u.TotalTokens).To(Equal(231))
|
||||
Expect(u.TimingTokenGeneration).To(BeZero())
|
||||
Expect(u.TimingPromptProcessing).To(BeZero())
|
||||
})
|
||||
It("includes timings when extraUsage is true", func() {
|
||||
tu := backend.TokenUsage{
|
||||
Prompt: 10, Completion: 20,
|
||||
TimingPromptProcessing: 0.5,
|
||||
TimingTokenGeneration: 1.5,
|
||||
}
|
||||
u := streamUsageFromTokenUsage(tu, true)
|
||||
Expect(u.TimingPromptProcessing).To(Equal(0.5))
|
||||
Expect(u.TimingTokenGeneration).To(Equal(1.5))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("OpenAIRequest.StreamOptions", func() {
|
||||
It("parses stream_options.include_usage=true", func() {
|
||||
body := []byte(`{
|
||||
@@ -177,3 +203,160 @@ var _ = Describe("streaming usage spec compliance", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Functional regression coverage for issue #9927: the streaming workers
|
||||
// must surface the cumulative TokenUsage returned by ComputeChoices to
|
||||
// their caller. The earlier broken implementations discarded that value
|
||||
// (`_, _, chatDeltas, err := ComputeChoices(...)`) and threw away the
|
||||
// counts on the floor, so the include_usage trailer always reported
|
||||
// zeros when tools were enabled.
|
||||
//
|
||||
// These tests stub backend.ModelInferenceFunc so the worker exercises the
|
||||
// real ComputeChoices → predFunc → LLMResponse pipeline. If a future change
|
||||
// drops the TokenUsage somewhere along that path, the assertions on the
|
||||
// returned value fail with a concrete count mismatch (e.g. 0 vs 213),
|
||||
// not with a "function undefined" compile error.
|
||||
var _ = Describe("streaming workers surface final TokenUsage (issue #9927)", func() {
|
||||
var (
|
||||
origInference modelInferenceFunc
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
origInference = backend.ModelInferenceFunc
|
||||
appCfg = config.NewApplicationConfig()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
backend.ModelInferenceFunc = origInference
|
||||
})
|
||||
|
||||
// mockBackendUsage installs a stub backend that yields one LLMResponse
|
||||
// carrying the supplied TokenUsage. ComputeChoices' single-attempt path
|
||||
// copies these counts into the value it returns to the worker.
|
||||
mockBackendUsage := func(usage backend.TokenUsage, response string) {
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
return func() (backend.LLMResponse, error) {
|
||||
return backend.LLMResponse{
|
||||
Response: response,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
makeReq := func() *schema.OpenAIRequest {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := &schema.OpenAIRequest{
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
req.Model = "test-model" // promoted from BasicModelRequest
|
||||
return req
|
||||
}
|
||||
|
||||
// drainResponses consumes everything the worker pushes onto the channel
|
||||
// so the worker is never blocked on its send. The channel is unbuffered
|
||||
// (matching production), so the drain goroutine must be running before
|
||||
// the worker is called.
|
||||
drainResponses := func(ch <-chan schema.OpenAIResponse) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for range ch {
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
Describe("processStream (no-tools path)", func() {
|
||||
It("returns the cumulative TokenUsage produced by the backend", func() {
|
||||
mockBackendUsage(backend.TokenUsage{Prompt: 18, Completion: 213}, "Hello there")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
|
||||
actual, err := processStream("prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Prompt).To(Equal(18),
|
||||
"prompt tokens must round-trip from backend through processStream")
|
||||
Expect(actual.Completion).To(Equal(213),
|
||||
"completion tokens must round-trip from backend through processStream")
|
||||
})
|
||||
|
||||
It("returns zero TokenUsage when the backend reports zero (negative control)", func() {
|
||||
mockBackendUsage(backend.TokenUsage{}, "x")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
|
||||
actual, err := processStream("prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Prompt).To(BeZero())
|
||||
Expect(actual.Completion).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("processStreamWithTools (tools path)", func() {
|
||||
It("returns the cumulative TokenUsage produced by the backend", func() {
|
||||
// This is the direct regression check for issue #9927: with tools
|
||||
// enabled, the trailer was reporting {0,0,0} because the worker
|
||||
// discarded ComputeChoices' second return value.
|
||||
mockBackendUsage(backend.TokenUsage{Prompt: 18, Completion: 213}, "answer")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
var textContent string
|
||||
|
||||
actual, err := processStreamWithTools("none", "prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0, &textContent)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Prompt).To(Equal(18),
|
||||
"prompt tokens must round-trip from backend through processStreamWithTools (issue #9927)")
|
||||
Expect(actual.Completion).To(Equal(213),
|
||||
"completion tokens must round-trip from backend through processStreamWithTools (issue #9927)")
|
||||
})
|
||||
|
||||
It("forwards timing fields when the backend supplies them", func() {
|
||||
mockBackendUsage(backend.TokenUsage{
|
||||
Prompt: 10, Completion: 20,
|
||||
TimingPromptProcessing: 0.5,
|
||||
TimingTokenGeneration: 1.5,
|
||||
}, "answer")
|
||||
|
||||
req := makeReq()
|
||||
cfg := &config.ModelConfig{}
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
done := drainResponses(responses)
|
||||
var textContent string
|
||||
|
||||
actual, err := processStreamWithTools("none", "prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0, &textContent)
|
||||
<-done
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.TimingPromptProcessing).To(Equal(0.5))
|
||||
Expect(actual.TimingTokenGeneration).To(Equal(1.5))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
405
core/http/endpoints/openai/chat_stream_workers.go
Normal file
405
core/http/endpoints/openai/chat_stream_workers.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// processStream is the streaming worker for chat completions with no
|
||||
// tool/function calling involved. It pushes SSE-shaped chunks onto
|
||||
// `responses` and returns the authoritative cumulative TokenUsage from
|
||||
// the prediction so the caller can populate the include_usage trailer
|
||||
// without having to peek inside the chunks.
|
||||
//
|
||||
// The caller owns the `responses` channel and is expected to read from
|
||||
// it while this function runs; processStream closes the channel before
|
||||
// returning.
|
||||
//
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the response writer wrapper
|
||||
// layer; no in-band signal from the worker is needed. The initial
|
||||
// role=assistant chunk is still emitted from the first token callback
|
||||
// rather than eagerly here, so the wrapper's lazy lookup against the
|
||||
// loader runs AFTER ml.Load has stamped the per-modelID node ID.
|
||||
func processStream(
|
||||
s string,
|
||||
req *schema.OpenAIRequest,
|
||||
cfg *config.ModelConfig,
|
||||
cl *config.ModelConfigLoader,
|
||||
startupOptions *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
responses chan schema.OpenAIResponse,
|
||||
id string,
|
||||
created int,
|
||||
) (backend.TokenUsage, error) {
|
||||
sentInitialRole := false
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
var template string
|
||||
if cfg.TemplateConfig.UseTokenizerTemplate {
|
||||
template = cfg.GetModelTemplate()
|
||||
} else {
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
_, finalUsage, _, err := ComputeChoices(req, s, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
// Always keep the Go-side extractor in sync with raw tokens so it
|
||||
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
// When C++ autoparser chat deltas are available, prefer them: they
|
||||
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
||||
// Otherwise fall back to Go-side extraction.
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
if !sentInitialRole {
|
||||
sentInitialRole = true
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
}
|
||||
if reasoningDelta != "" {
|
||||
delta.Reasoning = &reasoningDelta
|
||||
}
|
||||
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
return finalUsage, err
|
||||
}
|
||||
|
||||
// processStreamWithTools is the streaming worker for chat completions
|
||||
// with tools / function calling. Same contract as processStream: pushes
|
||||
// chunks onto `responses`, closes the channel, returns the cumulative
|
||||
// TokenUsage.
|
||||
//
|
||||
// Returning the TokenUsage as a normal Go value (rather than smuggling
|
||||
// it on a sentinel chunk) is the fix for issue #9927 — the previous
|
||||
// implementation discarded the value from ComputeChoices, so the
|
||||
// include_usage trailer reported zeros whenever `tools` was in play.
|
||||
func processStreamWithTools(
|
||||
noAction string,
|
||||
prompt string,
|
||||
req *schema.OpenAIRequest,
|
||||
cfg *config.ModelConfig,
|
||||
cl *config.ModelConfigLoader,
|
||||
startupOptions *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
responses chan schema.OpenAIResponse,
|
||||
id string,
|
||||
created int,
|
||||
textContentToReturn *string,
|
||||
) (backend.TokenUsage, error) {
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
if cfg.TemplateConfig.UseTokenizerTemplate {
|
||||
template = cfg.GetModelTemplate()
|
||||
} else {
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
sentReasoning := false
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
// X-LocalAI-Node attribution is handled by middleware.ExposeNodeHeader
|
||||
// at the wrapper layer; no in-band signalling from this worker.
|
||||
|
||||
_, finalUsage, chatDeltas, err := ComputeChoices(req, prompt, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
// Track whether ChatDeltas from the C++ autoparser contain
|
||||
// tool calls or content, so the retry decision can account for them.
|
||||
for _, d := range usage.ChatDeltas {
|
||||
if len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaToolCalls = true
|
||||
}
|
||||
if d.Content != "" {
|
||||
hasChatDeltaContent = true
|
||||
}
|
||||
}
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
if usage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
if reasoningDelta != "" {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentReasoning = true
|
||||
}
|
||||
|
||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||
// have been detected. Once the incremental parser finds tool calls,
|
||||
// content stops: per OpenAI spec, content and tool_calls don't mix.
|
||||
if lastEmittedCount == 0 && contentDelta != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &contentDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
// Try incremental XML parsing for streaming support using iterative parser
|
||||
// This allows emitting partial tool calls as they're being generated
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
|
||||
// Determine XML format from config
|
||||
var xmlFormat *functions.XMLToolCallFormat
|
||||
if cfg.FunctionsConfig.XMLFormat != nil {
|
||||
xmlFormat = cfg.FunctionsConfig.XMLFormat
|
||||
} else if cfg.FunctionsConfig.XMLFormatPreset != "" {
|
||||
xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset)
|
||||
}
|
||||
|
||||
// Use iterative parser for streaming (partial parsing enabled)
|
||||
// Try XML parsing first
|
||||
partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
|
||||
if parseErr == nil && len(partialResults) > 0 {
|
||||
// Emit new XML tool calls that weren't emitted before
|
||||
if len(partialResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(partialResults); i++ {
|
||||
toolCall := partialResults[i]
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: toolCall.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
select {
|
||||
case responses <- initialMessage:
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastEmittedCount = len(partialResults)
|
||||
}
|
||||
} else {
|
||||
// Try JSON tool call parsing for streaming.
|
||||
// Only emit NEW tool calls (same guard as XML parser above).
|
||||
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
||||
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
||||
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
||||
jsonObj := jsonResults[i]
|
||||
name, ok := jsonObj["name"].(string)
|
||||
if !ok || name == "" {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||
if argsStr, ok := argsVal.(string); ok {
|
||||
args = argsStr
|
||||
} else {
|
||||
argsBytes, _ := json.Marshal(argsVal)
|
||||
args = string(argsBytes)
|
||||
}
|
||||
}
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Index: 0,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
}
|
||||
lastEmittedCount = len(jsonResults)
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
func(attempt int) bool {
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry).
|
||||
// Also check ChatDelta flags: when the C++ autoparser is active,
|
||||
// tool calls and content are delivered via ChatDeltas while the
|
||||
// raw message is cleared. Without this check, we'd retry
|
||||
// unnecessarily, losing valid results and concatenating output.
|
||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||
hasContent := cleaned != "" || hasChatDeltaContent
|
||||
if !hasContent && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
hasChatDeltaToolCalls = false
|
||||
hasChatDeltaContent = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return finalUsage, err
|
||||
}
|
||||
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
||||
var functionResults []functions.FuncCallResults
|
||||
var reasoning string
|
||||
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
||||
functionResults = deltaToolCalls
|
||||
// Use content/reasoning from deltas too
|
||||
*textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning = extractor.Reasoning()
|
||||
cleanedResult := extractor.CleanedContent()
|
||||
*textContentToReturn = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn)
|
||||
// noAction is a sentinel "just answer" pseudo-function: not a real
|
||||
// tool call. Scan the whole slice rather than only index 0 so we
|
||||
// don't drop a real tool call that happens to follow a noAction
|
||||
// entry, and so the default branch isn't entered with only noAction
|
||||
// entries to emit as tool_calls.
|
||||
noActionToRun := !hasRealCall(functionResults, noAction)
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
// The final usage trailer (when the caller opted in with
|
||||
// stream_options.include_usage) is built by the outer streaming
|
||||
// loop from the TokenUsage this function returns, not from any
|
||||
// chunk on the responses channel.
|
||||
var result string
|
||||
if !sentInitialRole {
|
||||
var hqErr error
|
||||
result, hqErr = handleQuestion(cfg, functionResults, extractor.CleanedContent(), prompt)
|
||||
if hqErr != nil {
|
||||
xlog.Error("error handling question", "error", hqErr)
|
||||
return finalUsage, hqErr
|
||||
}
|
||||
}
|
||||
for _, chunk := range buildNoActionFinalChunks(
|
||||
id, req.Model, created,
|
||||
sentInitialRole, sentReasoning,
|
||||
result, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
|
||||
default:
|
||||
for _, chunk := range buildDeferredToolCallChunks(
|
||||
id, req.Model, created,
|
||||
functionResults, lastEmittedCount,
|
||||
sentInitialRole, *textContentToReturn,
|
||||
sentReasoning, reasoning,
|
||||
) {
|
||||
responses <- chunk
|
||||
}
|
||||
}
|
||||
|
||||
close(responses)
|
||||
return finalUsage, err
|
||||
}
|
||||
@@ -26,6 +26,11 @@ import (
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// process runs the streaming inference. X-LocalAI-Node attribution
|
||||
// (when --expose-node-header is on) is handled by
|
||||
// middleware.ExposeNodeHeader at the response writer wrapper layer:
|
||||
// the first SSE write triggers a lazy lookup against the loader, so
|
||||
// no in-band signalling is needed here.
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -106,6 +111,11 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer: the first c.Response().Write / Flush lazily reads the
|
||||
// node ID from the loader (post-ml.Load) and stamps the header
|
||||
// before the byte hits the underlying writer.
|
||||
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
@@ -274,6 +284,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the wrapper layer.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
@@ -102,6 +102,9 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the wrapper layer.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
131
core/http/middleware/node_header.go
Normal file
131
core/http/middleware/node_header.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// NodeHeaderName is the HTTP response header that, when --expose-node-header
|
||||
// is enabled, carries the ID of the distributed-mode worker node that served
|
||||
// the inference request. Off by default: node IDs reveal internal topology
|
||||
// and should not be exposed on a public endpoint.
|
||||
const NodeHeaderName = "X-LocalAI-Node"
|
||||
|
||||
// nodeHeaderWriter wraps an http.ResponseWriter and stamps the X-LocalAI-Node
|
||||
// header lazily on the first Write / WriteHeader / Flush call. The lazy
|
||||
// resolve is what makes this work for streaming: the picked node ID is only
|
||||
// known AFTER ml.Load runs (i.e. on the first SSE chunk), so resolving at
|
||||
// request entry would attach the previous request's routing decision (or
|
||||
// nothing on a cold cache).
|
||||
type nodeHeaderWriter struct {
|
||||
http.ResponseWriter
|
||||
resolve func() string
|
||||
set bool
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) maybeSet() {
|
||||
if w.set {
|
||||
return
|
||||
}
|
||||
w.set = true
|
||||
if id := w.resolve(); id != "" {
|
||||
w.Header().Set(NodeHeaderName, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) Write(b []byte) (int, error) {
|
||||
w.maybeSet()
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) WriteHeader(code int) {
|
||||
w.maybeSet()
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush keeps SSE handlers working: Echo's Response.Flush goes through
|
||||
// http.NewResponseController which walks Unwrap() chains and invokes Flush
|
||||
// on the first wrapper that implements http.Flusher. By implementing it
|
||||
// here we both stamp the header before the underlying writer flushes AND
|
||||
// keep the streaming path alive.
|
||||
func (w *nodeHeaderWriter) Flush() {
|
||||
w.maybeSet()
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack preserves WebSocket / raw-conn handlers that need to take over the
|
||||
// underlying TCP connection (e.g. /v1/realtime). Without this the wrapper
|
||||
// would silently break those endpoints.
|
||||
//
|
||||
// When the underlying writer does not implement http.Hijacker we return
|
||||
// http.ErrNotSupported so callers using errors.Is (notably
|
||||
// http.NewResponseController.Hijack) detect the condition through the
|
||||
// standard sentinel rather than a string-matched custom error.
|
||||
func (w *nodeHeaderWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("hijack not supported: %w", http.ErrNotSupported)
|
||||
}
|
||||
|
||||
// Unwrap lets http.NewResponseController reach through us to find optional
|
||||
// interfaces (CloseNotifier, SetReadDeadline, etc.) on the real writer.
|
||||
func (w *nodeHeaderWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
// ExposeNodeHeader installs a per-request response writer wrapper that
|
||||
// stamps the X-LocalAI-Node header from the currently-loaded model's node
|
||||
// ID on the first write. Off by default; opted in via --expose-node-header
|
||||
// / LOCALAI_EXPOSE_NODE_HEADER. The model name is read from the standard
|
||||
// per-request context key set by the request-extractor middleware chain
|
||||
// (CONTEXT_LOCALS_KEY_MODEL_NAME), so any handler that goes through the
|
||||
// usual SetModelAndConfig wiring is automatically covered.
|
||||
//
|
||||
// Best-effort: under heavy concurrency for the same model across multiple
|
||||
// replicas, the header may reflect a recent routing decision rather than
|
||||
// this exact request's, because the model loader's per-modelID store entry
|
||||
// is overwritten on every routing decision. Acceptable for observability
|
||||
// and debugging.
|
||||
func ExposeNodeHeader(appCfg *config.ApplicationConfig, ml *model.ModelLoader) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if appCfg == nil || !appCfg.ExposeNodeHeader || ml == nil {
|
||||
return next(c)
|
||||
}
|
||||
orig := c.Response().Writer
|
||||
wrapper := &nodeHeaderWriter{
|
||||
ResponseWriter: orig,
|
||||
resolve: func() string {
|
||||
modelName, _ := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
// Pure store read - never invokes HealthCheck and
|
||||
// never acquires ml.mu, so the wrapper cannot stall
|
||||
// the response writer for the 2-minute gRPC
|
||||
// HealthCheck timeout that CheckIsLoaded can pay
|
||||
// when the recently-healthy cache window has
|
||||
// expired. The X-LocalAI-Node header is
|
||||
// best-effort observability; a stale value is
|
||||
// preferable to blocking the byte stream.
|
||||
return ml.LookupNodeID(modelName)
|
||||
},
|
||||
}
|
||||
c.Response().Writer = wrapper
|
||||
defer func() {
|
||||
c.Response().Writer = orig
|
||||
}()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
225
core/http/middleware/node_header_integration_test.go
Normal file
225
core/http/middleware/node_header_integration_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package middleware_test
|
||||
|
||||
// Route-level integration coverage for the X-LocalAI-Node middleware.
|
||||
//
|
||||
// What this file pins (and why a separate spec on top of the unit tests
|
||||
// in node_header_test.go):
|
||||
//
|
||||
// - The unit tests in node_header_test.go exercise the wrapper by
|
||||
// invoking `mw(handler)(c)` directly against a hand-built
|
||||
// echo.Context. That misses regressions where the contract between
|
||||
// the real Echo router and the wrapper breaks: e.g. middleware
|
||||
// installation via e.Use() loses the wrapper because the framework
|
||||
// re-decorates c.Response().Writer after middleware setup, or a
|
||||
// handler that bypasses c.Response().Writer (writing to some other
|
||||
// captured surface).
|
||||
//
|
||||
// - This spec dispatches a real HTTP request through e.ServeHTTP into
|
||||
// a streaming handler shaped like chat.go's streaming branch: set
|
||||
// SSE headers, write chunks via c.Response().Write, Flush. It
|
||||
// proves that:
|
||||
// 1. Middleware installed via e.Use() is on the writer chain
|
||||
// when the handler runs.
|
||||
// 2. The wrapper's lazy maybeSet fires on the first underlying
|
||||
// Write/Flush, so X-LocalAI-Node lands on the response map
|
||||
// BEFORE the first body byte is committed.
|
||||
// 3. The header is present in the recorded response (i.e. it
|
||||
// isn't dropped because we tried to set it post-WriteHeader).
|
||||
//
|
||||
// Out of scope (and why):
|
||||
//
|
||||
// - We do NOT wire core/http/endpoints/openai.ChatEndpoint
|
||||
// end-to-end. ChatEndpoint depends on templates.Evaluator, the
|
||||
// MCP NATS client, and the LocalAI Assistant holder; standing
|
||||
// those up just to assert header ordering is out of proportion to
|
||||
// the property under test. The handler used here mirrors
|
||||
// chat.go's streaming branch and exercises the SAME middleware →
|
||||
// c.Response().Writer → SSE write path as production. If
|
||||
// chat.go's streaming branch ever stops going through
|
||||
// c.Response().Writer (e.g. it starts using a captured raw
|
||||
// http.ResponseWriter from a different seam), this test will not
|
||||
// notice; guard that with a code review checklist on chat.go.
|
||||
//
|
||||
// - We do NOT exercise the real processStream worker here.
|
||||
// processStream lives in core/http/endpoints/openai, which itself
|
||||
// imports core/http/middleware - a regular import from middleware
|
||||
// into openai would create a cycle. processStream is independently
|
||||
// covered in core/http/endpoints/openai/chat_stream_usage_test.go;
|
||||
// the only behaviour we need at this layer is the writer-contract
|
||||
// check above, which the synthetic SSE handler reproduces faithfully.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// orderRecorder snapshots the X-LocalAI-Node header value AT THE MOMENT
|
||||
// the underlying writer is asked to commit each event. Any header set on
|
||||
// the response map AFTER the first write/flush is dropped on the wire,
|
||||
// so this is the ground-truth observation a real SSE client would see.
|
||||
type orderRecorder struct {
|
||||
http.ResponseWriter
|
||||
mu sync.Mutex
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderRecorder) record(ev string) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
o.events = append(o.events, ev)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) snapshot() []string {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
out := make([]string, len(o.events))
|
||||
copy(out, o.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *orderRecorder) WriteHeader(code int) {
|
||||
o.record(fmt.Sprintf("header:%d:node=%s", code, o.Header().Get(middleware.NodeHeaderName)))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) Write(b []byte) (int, error) {
|
||||
o.record(fmt.Sprintf("write:node=%s", o.Header().Get(middleware.NodeHeaderName)))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) Flush() {
|
||||
o.record(fmt.Sprintf("flush:node=%s", o.Header().Get(middleware.NodeHeaderName)))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware (route-level integration)", func() {
|
||||
const (
|
||||
modelID = "integration-model"
|
||||
fakeNodeID = "node-route-7"
|
||||
)
|
||||
|
||||
var (
|
||||
ml *model.ModelLoader
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(GinkgoT().TempDir()),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ml = model.NewModelLoader(systemState)
|
||||
|
||||
// Stamp the loader with a model entry that already has the
|
||||
// node ID set. In production the SmartRouter stamps this
|
||||
// during ml.Load before the first chunk is emitted; here we
|
||||
// pre-stamp it because the assertion is about wire ordering
|
||||
// (header-before-first-byte), not about ml.Load timing
|
||||
// (which is covered separately in pkg/model/lookup_node_id_test.go).
|
||||
m := model.NewModelWithClient(modelID, "10.0.0.1:50051", nil)
|
||||
m.SetNodeID(fakeNodeID)
|
||||
m.MarkHealthy()
|
||||
store := model.NewInMemoryModelStore()
|
||||
store.Set(modelID, m)
|
||||
ml.SetModelStore(store)
|
||||
|
||||
appCfg = config.NewApplicationConfig()
|
||||
appCfg.ExposeNodeHeader = true
|
||||
})
|
||||
|
||||
It("stamps X-LocalAI-Node before the first SSE byte via the real router + middleware chain", func() {
|
||||
// Build a real Echo router. We need the tracker to sit BELOW
|
||||
// the ExposeNodeHeader wrapper in the writer chain (so its
|
||||
// recorded snapshot reflects what bytes-on-the-wire see AFTER
|
||||
// the wrapper has had a chance to stamp the header). Install
|
||||
// the tracker via a middleware that runs BEFORE
|
||||
// ExposeNodeHeader; Echo's middleware execution order matches
|
||||
// e.Use() call order, so the first Use() wraps the OUTER
|
||||
// layer of the writer chain (i.e. the wrapper installed by
|
||||
// the second Use() wraps the tracker installed by the first).
|
||||
var (
|
||||
recorderMu sync.Mutex
|
||||
tracker *orderRecorder
|
||||
)
|
||||
e := echo.New()
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
recorderMu.Lock()
|
||||
tracker = &orderRecorder{ResponseWriter: c.Response().Writer}
|
||||
c.Response().Writer = tracker
|
||||
recorderMu.Unlock()
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
e.Use(middleware.ExposeNodeHeader(appCfg, ml))
|
||||
|
||||
e.POST("/v1/chat/completions", func(c echo.Context) error {
|
||||
// Mirror SetModelAndConfig: stash the model name on the
|
||||
// per-request locals so the middleware's resolve closure
|
||||
// can pick it up. Every real chat / completion handler
|
||||
// goes through this contract.
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
|
||||
// SSE response prelude (same shape as chat.go).
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Emit a handful of SSE chunks. The very first
|
||||
// Write/Flush is what triggers the middleware
|
||||
// wrapper's maybeSet, so the X-LocalAI-Node header
|
||||
// MUST already be on the response map by the time the
|
||||
// byte is committed.
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := c.Response().Write([]byte(fmt.Sprintf("data: chunk %d\n\n", i)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(""))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
recorderMu.Lock()
|
||||
Expect(tracker).ToNot(BeNil(), "handler must run and install the order recorder")
|
||||
events := tracker.snapshot()
|
||||
recorderMu.Unlock()
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Header().Get(middleware.NodeHeaderName)).To(Equal(fakeNodeID),
|
||||
"production contract: header must reach the wire on a streamed response")
|
||||
|
||||
Expect(events).ToNot(BeEmpty(),
|
||||
"expected at least one underlying-writer event from the streaming handler")
|
||||
|
||||
// The very first observed event is the moment the wrapper
|
||||
// commits to the wire. Its recorded node= value is what a
|
||||
// real HTTP client would actually see; anything that lands
|
||||
// AFTER this byte is invisible.
|
||||
first := events[0]
|
||||
Expect(first).To(ContainSubstring("node="+fakeNodeID),
|
||||
"first writer event must carry the X-LocalAI-Node header (chain: middleware.Use -> e.POST -> handler.Write/Flush); got events: %v", events)
|
||||
|
||||
// Body sanity: SSE chunks made it to the recorder.
|
||||
Expect(rec.Body.String()).To(ContainSubstring("data: chunk 0"))
|
||||
})
|
||||
})
|
||||
291
core/http/middleware/node_header_test.go
Normal file
291
core/http/middleware/node_header_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// orderedWriter records the order in which header-snapshot vs body-byte
|
||||
// events happen. Used by the streaming spec to assert that the X-LocalAI-Node
|
||||
// header lands on the response BEFORE the first body byte is committed to
|
||||
// the underlying writer.
|
||||
type orderedWriter struct {
|
||||
http.ResponseWriter
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderedWriter) WriteHeader(code int) {
|
||||
o.events = append(o.events, "header:"+http.StatusText(code))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Write(b []byte) (int, error) {
|
||||
// Snapshot the X-LocalAI-Node header value AT THE INSTANT the underlying
|
||||
// writer is asked to commit bytes. This is what real HTTP clients
|
||||
// effectively observe: anything set on the header map AFTER this point
|
||||
// would be silently dropped.
|
||||
o.events = append(o.events, "write:node="+o.Header().Get(NodeHeaderName))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Flush() {
|
||||
o.events = append(o.events, "flush:node="+o.Header().Get(NodeHeaderName))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware", func() {
|
||||
const (
|
||||
modelID = "qwen3.5-0.8b"
|
||||
fakeNodeID = "node-abcdef"
|
||||
)
|
||||
|
||||
var (
|
||||
e *echo.Echo
|
||||
ml *model.ModelLoader
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
// loadModel pre-populates the loader's in-memory store with a model
|
||||
// entry whose NodeID is set to `nodeID` (or left empty). Marking the
|
||||
// model recently-healthy short-circuits the gRPC HealthCheck inside
|
||||
// CheckIsLoaded so the test does not try to dial a bogus address.
|
||||
loadModel := func(id, nodeID string) {
|
||||
m := model.NewModelWithClient(id, "10.0.0.1:50051", nil)
|
||||
if nodeID != "" {
|
||||
m.SetNodeID(nodeID)
|
||||
}
|
||||
m.MarkHealthy()
|
||||
store := model.NewInMemoryModelStore()
|
||||
store.Set(id, m)
|
||||
ml.SetModelStore(store)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
ml = model.NewModelLoader(&system.SystemState{})
|
||||
appCfg = &config.ApplicationConfig{}
|
||||
})
|
||||
|
||||
// run executes the middleware against a fake handler that stashes the
|
||||
// model name on the request context (the same way the
|
||||
// request-extractor middleware does in production) and then writes a
|
||||
// trivial body to trigger the wrapper. Returns the recorded response.
|
||||
run := func(handler echo.HandlerFunc) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
When("ExposeNodeHeader is false", func() {
|
||||
It("does not set the X-LocalAI-Node header", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not even install the wrapper (writer is unchanged)", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
origWriter := c.Response().Writer
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
// Pass-through must leave the writer identity intact so
|
||||
// no overhead is added on the hot path when the feature
|
||||
// is off.
|
||||
Expect(c.Response().Writer).To(BeIdenticalTo(origWriter))
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the model is loaded with a node ID", func() {
|
||||
It("sets the X-LocalAI-Node header on a buffered response", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
|
||||
It("sets the header even on a 500 error response (Write still triggers maybeSet)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusInternalServerError, "boom")
|
||||
})
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusInternalServerError))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but no model is loaded for the request", func() {
|
||||
It("does not set the header (cold cache stays silent)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
// No model loaded.
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the model is loaded but has no node ID", func() {
|
||||
It("does not set the header (in-process model, not distributed)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, "") // local model: no node ID stamped
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but no model name is stashed on the context", func() {
|
||||
It("does not set the header (handler did not opt in)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
// Intentionally skip the c.Set call.
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler streams via Flush before any Write", func() {
|
||||
It("sets the header BEFORE the first byte hits the underlying writer", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
// Wrap the recorder with an order-tracking writer so we can
|
||||
// assert that the header is on the response map by the time
|
||||
// the first body byte is committed. This is the property
|
||||
// that protected the pre-refactor streaming bug: if the
|
||||
// wrapper stamped lazily but AFTER the byte commit, real
|
||||
// SSE clients would see the body without the header.
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
// Simulate an SSE handler: flush headers, then write a
|
||||
// chunk and flush again. The wrapper must stamp the
|
||||
// node ID on the first call - either Flush or Write,
|
||||
// whichever comes first.
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Flush()
|
||||
_, err := c.Response().Write([]byte("data: chunk\n\n"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// First recorded event on the underlying writer must show
|
||||
// the header already populated. The first event is either
|
||||
// flush or write; either way the node ID must be on it.
|
||||
Expect(tracker.events).ToNot(BeEmpty())
|
||||
Expect(tracker.events[0]).To(HavePrefix("flush:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler writes a body without an explicit WriteHeader", func() {
|
||||
It("still stamps the header before the implicit 200 commit", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
_, err := c.Response().Write([]byte("body"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// Echo's Response.Write calls WriteHeader on the underlying
|
||||
// writer first, then Write. Both must see the header
|
||||
// already populated (the wrapper's maybeSet ran inside both
|
||||
// WriteHeader and Write before they hit `tracker`).
|
||||
Expect(len(tracker.events)).To(BeNumerically(">=", 2))
|
||||
Expect(tracker.events[0]).To(HavePrefix("header:"))
|
||||
Expect(tracker.events[1]).To(Equal("write:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the model's node ID changes between request entry and first write", func() {
|
||||
It("uses the value present AT the first write (late binding)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, "stale-node-A")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
// Simulate ml.Load running mid-request and re-stamping
|
||||
// the model with this request's actual routing decision.
|
||||
m := ml.CheckIsLoaded(modelID)
|
||||
Expect(m).ToNot(BeNil())
|
||||
m.SetNodeID("fresh-node-B")
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal("fresh-node-B"),
|
||||
"the wrapper must read the node ID lazily at first write, not eagerly at entry")
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -17,16 +17,20 @@ import (
|
||||
)
|
||||
|
||||
type APIExchangeRequest struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
BodyTruncated bool `json:"body_truncated,omitempty"`
|
||||
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
|
||||
}
|
||||
|
||||
type APIExchangeResponse struct {
|
||||
Status int `json:"status"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
Status int `json:"status"`
|
||||
Headers *http.Header `json:"headers"`
|
||||
Body *[]byte `json:"body"`
|
||||
BodyTruncated bool `json:"body_truncated,omitempty"`
|
||||
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
|
||||
}
|
||||
|
||||
type APIExchange struct {
|
||||
@@ -66,11 +70,29 @@ var doInitializeTracing = sync.OnceFunc(func() {
|
||||
|
||||
type bodyWriter struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
body *bytes.Buffer
|
||||
maxBytes int // 0 = unlimited capture
|
||||
truncated bool
|
||||
totalBytes int // bytes the upstream handler wrote, even past the cap
|
||||
}
|
||||
|
||||
func (w *bodyWriter) Write(b []byte) (int, error) {
|
||||
w.body.Write(b)
|
||||
// Capture into the trace buffer up to maxBytes, then drop the overflow
|
||||
// so a chatty endpoint can't grow the buffer without bound. The full
|
||||
// payload still flows through to the real client below.
|
||||
w.totalBytes += len(b)
|
||||
if w.maxBytes <= 0 {
|
||||
w.body.Write(b)
|
||||
} else if remain := w.maxBytes - w.body.Len(); remain > 0 {
|
||||
if remain >= len(b) {
|
||||
w.body.Write(b)
|
||||
} else {
|
||||
w.body.Write(b[:remain])
|
||||
w.truncated = true
|
||||
}
|
||||
} else {
|
||||
w.truncated = true
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
@@ -80,6 +102,20 @@ func (w *bodyWriter) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// truncateForTrace returns a defensive copy of body capped at maxBytes,
|
||||
// and a flag indicating whether the cap forced truncation. maxBytes <= 0
|
||||
// disables the cap.
|
||||
func truncateForTrace(body []byte, maxBytes int) ([]byte, bool) {
|
||||
if maxBytes <= 0 || len(body) <= maxBytes {
|
||||
out := make([]byte, len(body))
|
||||
copy(out, body)
|
||||
return out, false
|
||||
}
|
||||
out := make([]byte, maxBytes)
|
||||
copy(out, body[:maxBytes])
|
||||
return out, true
|
||||
}
|
||||
|
||||
func initializeTracing(maxItems int) {
|
||||
tracingMaxItems = maxItems
|
||||
doInitializeTracing()
|
||||
@@ -134,11 +170,18 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Cap captured payload size. Without this, /embeddings and
|
||||
// streaming /chat/completions blow the in-memory buffer into the
|
||||
// tens of MB, which then locks the admin Traces UI fetching the
|
||||
// JSON dump faster than the 5s auto-refresh.
|
||||
maxBodyBytes := app.ApplicationConfig().TracingMaxBodyBytes
|
||||
|
||||
// Wrap response writer to capture body
|
||||
resBody := new(bytes.Buffer)
|
||||
mw := &bodyWriter{
|
||||
ResponseWriter: c.Response().Writer,
|
||||
body: resBody,
|
||||
maxBytes: maxBodyBytes,
|
||||
}
|
||||
c.Response().Writer = mw
|
||||
|
||||
@@ -159,8 +202,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
// via any heap-dump-style introspection, and tokens shouldn't
|
||||
// outlive the request that carried them.
|
||||
requestHeaders := redactSensitiveHeaders(c.Request().Header)
|
||||
requestBody := make([]byte, len(body))
|
||||
copy(requestBody, body)
|
||||
requestBody, requestTruncated := truncateForTrace(body, maxBodyBytes)
|
||||
responseHeaders := redactSensitiveHeaders(c.Response().Header())
|
||||
responseBody := make([]byte, resBody.Len())
|
||||
copy(responseBody, resBody.Bytes())
|
||||
@@ -168,15 +210,19 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Request: APIExchangeRequest{
|
||||
Method: c.Request().Method,
|
||||
Path: c.Path(),
|
||||
Headers: &requestHeaders,
|
||||
Body: &requestBody,
|
||||
Method: c.Request().Method,
|
||||
Path: c.Path(),
|
||||
Headers: &requestHeaders,
|
||||
Body: &requestBody,
|
||||
BodyTruncated: requestTruncated,
|
||||
BodyBytes: len(body),
|
||||
},
|
||||
Response: APIExchangeResponse{
|
||||
Status: status,
|
||||
Headers: &responseHeaders,
|
||||
Body: &responseBody,
|
||||
Status: status,
|
||||
Headers: &responseHeaders,
|
||||
Body: &responseBody,
|
||||
BodyTruncated: mw.truncated,
|
||||
BodyBytes: mw.totalBytes,
|
||||
},
|
||||
}
|
||||
if handlerErr != nil {
|
||||
|
||||
116
core/http/middleware/trace_body_cap_test.go
Normal file
116
core/http/middleware/trace_body_cap_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// The trace middleware copies request and response bodies into an in-memory
|
||||
// buffer that backs the admin /api/traces endpoint. With no upper bound a
|
||||
// chatty workload (embeddings, large completions) trivially produces a
|
||||
// multi-MB response that locks the Traces UI in a loading state — fetching
|
||||
// and parsing the payload outruns the 5-second auto-refresh. These specs
|
||||
// pin the capping contract so future refactors keep both the cap and the
|
||||
// passthrough to the real client intact.
|
||||
|
||||
var _ = Describe("bodyWriter capping", func() {
|
||||
It("captures the full body when maxBytes is 0 (unlimited)", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 0}
|
||||
|
||||
payload := []byte(strings.Repeat("x", 4096))
|
||||
n, err := bw.Write(payload)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(len(payload)))
|
||||
Expect(buf.Len()).To(Equal(len(payload)))
|
||||
Expect(downstream.Body.Len()).To(Equal(len(payload)))
|
||||
Expect(bw.truncated).To(BeFalse())
|
||||
})
|
||||
|
||||
It("stops appending to the trace buffer once maxBytes is reached but still forwards to the client", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 100}
|
||||
|
||||
payload := []byte(strings.Repeat("a", 250))
|
||||
n, err := bw.Write(payload)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(len(payload)), "Write must return the full byte count so callers see no short write")
|
||||
Expect(buf.Len()).To(Equal(100), "trace buffer should hold exactly maxBytes")
|
||||
Expect(downstream.Body.Len()).To(Equal(len(payload)), "client must still receive every byte")
|
||||
Expect(bw.truncated).To(BeTrue())
|
||||
})
|
||||
|
||||
It("handles a write that straddles the cap by keeping only the leading slice", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 10}
|
||||
|
||||
_, err := bw.Write([]byte("12345"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bw.truncated).To(BeFalse())
|
||||
|
||||
_, err = bw.Write([]byte("67890ABCDE"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(buf.String()).To(Equal("1234567890"))
|
||||
Expect(downstream.Body.String()).To(Equal("1234567890ABCDE"))
|
||||
Expect(bw.truncated).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores further writes after the cap was already hit", func() {
|
||||
downstream := httptest.NewRecorder()
|
||||
buf := &bytes.Buffer{}
|
||||
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 4}
|
||||
|
||||
_, _ = bw.Write([]byte("AAAA"))
|
||||
_, _ = bw.Write([]byte("BBBB"))
|
||||
_, _ = bw.Write([]byte("CCCC"))
|
||||
|
||||
Expect(buf.String()).To(Equal("AAAA"))
|
||||
Expect(downstream.Body.String()).To(Equal("AAAABBBBCCCC"))
|
||||
Expect(bw.truncated).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("truncateForTrace", func() {
|
||||
It("returns the input unchanged when below the cap", func() {
|
||||
in := []byte("hello")
|
||||
out, truncated := truncateForTrace(in, 1024)
|
||||
Expect(truncated).To(BeFalse())
|
||||
Expect(out).To(Equal(in))
|
||||
})
|
||||
|
||||
It("truncates when the input exceeds the cap and signals truncation", func() {
|
||||
in := []byte(strings.Repeat("z", 200))
|
||||
out, truncated := truncateForTrace(in, 64)
|
||||
Expect(truncated).To(BeTrue())
|
||||
Expect(out).To(HaveLen(64))
|
||||
Expect(string(out)).To(Equal(strings.Repeat("z", 64)))
|
||||
})
|
||||
|
||||
It("treats maxBytes <= 0 as unlimited (back-compat with current default)", func() {
|
||||
in := []byte(strings.Repeat("q", 10_000))
|
||||
out, truncated := truncateForTrace(in, 0)
|
||||
Expect(truncated).To(BeFalse())
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
})
|
||||
|
||||
It("does not retain the caller's backing array (defensive copy)", func() {
|
||||
in := []byte("abcdefghij")
|
||||
out, truncated := truncateForTrace(in, 4)
|
||||
Expect(truncated).To(BeTrue())
|
||||
Expect(string(out)).To(Equal("abcd"))
|
||||
|
||||
// Mutating the source must not corrupt the trace copy.
|
||||
in[0] = 'Z'
|
||||
Expect(string(out)).To(Equal("abcd"))
|
||||
})
|
||||
})
|
||||
@@ -52,11 +52,22 @@ test.describe('Traces Settings', () => {
|
||||
await page.locator('button', { hasText: 'Tracing is' }).click()
|
||||
await expect(page.locator('text=Enable Tracing')).toBeVisible()
|
||||
|
||||
const maxItemsInput = page.locator('input[type="number"]')
|
||||
// The Tracing panel has two numeric inputs (Max Items and Max Body Bytes).
|
||||
// Disambiguate by placeholder so adding a third field later doesn't break this.
|
||||
const maxItemsInput = page.getByPlaceholder('100')
|
||||
await maxItemsInput.fill('500')
|
||||
await expect(maxItemsInput).toHaveValue('500')
|
||||
})
|
||||
|
||||
test('set max body bytes value', async ({ page }) => {
|
||||
await page.locator('button', { hasText: 'Tracing is' }).click()
|
||||
await expect(page.locator('text=Enable Tracing')).toBeVisible()
|
||||
|
||||
const maxBodyBytesInput = page.getByPlaceholder('65536')
|
||||
await maxBodyBytesInput.fill('16384')
|
||||
await expect(maxBodyBytesInput).toHaveValue('16384')
|
||||
})
|
||||
|
||||
test('save shows toast', async ({ page }) => {
|
||||
// Expand settings
|
||||
await page.locator('button', { hasText: 'Tracing is' }).click()
|
||||
|
||||
@@ -649,6 +649,7 @@
|
||||
align-items: center;
|
||||
gap: var(--spacing-md);
|
||||
padding: var(--spacing-xs) 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.operation-info {
|
||||
@@ -739,6 +740,110 @@
|
||||
color: var(--color-error);
|
||||
}
|
||||
|
||||
/* Operations bar: per-node breakdown (multi-worker installs) */
|
||||
.operation-expand {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-text-muted);
|
||||
cursor: pointer;
|
||||
padding: 0 var(--spacing-xs);
|
||||
font-size: var(--text-xs);
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
.operation-expand:hover {
|
||||
color: var(--color-text-primary);
|
||||
}
|
||||
.operation-expand-label {
|
||||
font-size: var(--text-xs);
|
||||
}
|
||||
|
||||
.operation-nodes-list {
|
||||
list-style: none;
|
||||
margin: var(--spacing-xs) 0 0;
|
||||
padding: var(--spacing-xs) 0 0;
|
||||
border-top: 1px solid var(--color-border-subtle);
|
||||
flex-basis: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
.operation-node {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--spacing-sm);
|
||||
padding: var(--spacing-xs) 0;
|
||||
font-size: var(--text-xs);
|
||||
color: var(--color-text-muted);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.operation-node-status {
|
||||
padding: 2px 6px;
|
||||
border-radius: var(--radius-md);
|
||||
font-size: 0.65rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.025em;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.operation-node-status-success {
|
||||
background: var(--color-success-light);
|
||||
color: var(--color-success);
|
||||
}
|
||||
.operation-node-status-error {
|
||||
background: var(--color-error-light);
|
||||
color: var(--color-error);
|
||||
}
|
||||
.operation-node-status-queued {
|
||||
background: var(--color-bg-tertiary);
|
||||
color: var(--color-text-muted);
|
||||
}
|
||||
.operation-node-status-running_on_worker {
|
||||
background: var(--color-warning-light);
|
||||
color: var(--color-warning);
|
||||
}
|
||||
.operation-node-status-downloading {
|
||||
background: var(--color-primary-light);
|
||||
color: var(--color-primary);
|
||||
}
|
||||
.operation-node-name {
|
||||
font-weight: 500;
|
||||
color: var(--color-text-secondary);
|
||||
}
|
||||
.operation-node-file {
|
||||
font-family: var(--font-mono);
|
||||
color: var(--color-text-tertiary);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 30ch;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.operation-node-bytes {
|
||||
font-variant-numeric: tabular-nums;
|
||||
color: var(--color-text-tertiary);
|
||||
}
|
||||
.operation-node-pct {
|
||||
font-variant-numeric: tabular-nums;
|
||||
color: var(--color-primary);
|
||||
font-weight: 500;
|
||||
}
|
||||
.operation-node-error {
|
||||
color: var(--color-error);
|
||||
}
|
||||
.operation-node-bar-container {
|
||||
flex-basis: 100%;
|
||||
height: 3px;
|
||||
background: var(--color-surface-sunken);
|
||||
border-radius: var(--radius-full);
|
||||
overflow: hidden;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
.operation-node-bar {
|
||||
height: 100%;
|
||||
background: var(--color-primary);
|
||||
border-radius: var(--radius-full);
|
||||
transition: width var(--duration-slow, 0.3s) var(--ease-spring, ease);
|
||||
}
|
||||
|
||||
/* Toast */
|
||||
.toast-container {
|
||||
position: fixed;
|
||||
|
||||
@@ -1,14 +1,33 @@
|
||||
import { useState } from 'react'
|
||||
import { useOperations } from '../hooks/useOperations'
|
||||
|
||||
const nodeStatusLabels = {
|
||||
success: 'Done',
|
||||
error: 'Failed',
|
||||
queued: 'Queued',
|
||||
running_on_worker: 'Worker busy',
|
||||
downloading: 'Downloading',
|
||||
}
|
||||
|
||||
const runningOnWorkerTooltip = 'NATS round-trip timed out, but the worker is still installing in the background. The reconciler will confirm completion.'
|
||||
|
||||
export default function OperationsBar() {
|
||||
const { operations, cancelOperation, dismissFailedOp } = useOperations()
|
||||
const [expanded, setExpanded] = useState({})
|
||||
|
||||
if (operations.length === 0) return null
|
||||
|
||||
const toggle = (key) => setExpanded((m) => ({ ...m, [key]: !m[key] }))
|
||||
|
||||
return (
|
||||
<div className="operations-bar">
|
||||
{operations.map(op => (
|
||||
<div key={op.jobID || op.id} className="operation-item">
|
||||
{operations.map(op => {
|
||||
const key = op.jobID || op.id
|
||||
const nodes = Array.isArray(op.nodes) ? op.nodes : []
|
||||
const canExpand = nodes.length > 1
|
||||
const isOpen = !!expanded[key]
|
||||
return (
|
||||
<div key={key} className="operation-item">
|
||||
<div className="operation-info">
|
||||
{op.error ? (
|
||||
<i className="fas fa-circle-exclamation" style={{ color: 'var(--color-error)', marginRight: 'var(--spacing-xs)' }} />
|
||||
@@ -80,8 +99,55 @@ export default function OperationsBar() {
|
||||
<i className="fas fa-xmark" />
|
||||
</button>
|
||||
) : null}
|
||||
{canExpand && (
|
||||
<button
|
||||
type="button"
|
||||
className="operation-expand"
|
||||
onClick={() => toggle(key)}
|
||||
aria-expanded={isOpen}
|
||||
title={isOpen ? 'Hide per-node detail' : `Show ${nodes.length} nodes`}
|
||||
>
|
||||
<i className={`fas fa-chevron-${isOpen ? 'up' : 'down'}`} />
|
||||
<span className="operation-expand-label">{nodes.length} nodes</span>
|
||||
</button>
|
||||
)}
|
||||
{canExpand && isOpen && (
|
||||
<ul className="operation-nodes-list">
|
||||
{nodes.map((n) => (
|
||||
<li key={n.node_id} className={`operation-node operation-node-${n.status}`}>
|
||||
<span
|
||||
className={`operation-node-status operation-node-status-${n.status}`}
|
||||
title={n.status === 'running_on_worker' ? runningOnWorkerTooltip : undefined}
|
||||
>
|
||||
{nodeStatusLabels[n.status] || n.status}
|
||||
</span>
|
||||
<span className="operation-node-name">{n.node_name || n.node_id}</span>
|
||||
{n.file_name && <span className="operation-node-file">{n.file_name}</span>}
|
||||
{(n.current || n.total) && (
|
||||
<span className="operation-node-bytes">
|
||||
{n.current || '?'} / {n.total || '?'}
|
||||
</span>
|
||||
)}
|
||||
{n.percentage > 0 && (
|
||||
<span className="operation-node-pct">{Math.round(n.percentage)}%</span>
|
||||
)}
|
||||
{n.error && (
|
||||
<span className="operation-node-error" title={n.error}>
|
||||
{n.error.length > 80 ? n.error.slice(0, 80) + '...' : n.error}
|
||||
</span>
|
||||
)}
|
||||
{n.percentage > 0 && n.percentage < 100 && (
|
||||
<div className="operation-node-bar-container">
|
||||
<div className="operation-node-bar" style={{ width: `${n.percentage}%` }} />
|
||||
</div>
|
||||
)}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { useState, useEffect, useCallback, useRef, useMemo } from 'react'
|
||||
import { useParams, useSearchParams, useOutletContext, Link } from 'react-router-dom'
|
||||
import { backendLogsApi } from '../utils/api'
|
||||
import { useParams, useSearchParams, useOutletContext, Link, Navigate } from 'react-router-dom'
|
||||
import { backendLogsApi, nodesApi } from '../utils/api'
|
||||
import { formatTimestamp } from '../utils/format'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import { useDistributedMode } from '../hooks/useDistributedMode'
|
||||
|
||||
function wsUrl(path) {
|
||||
const proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
|
||||
@@ -274,11 +275,158 @@ function BackendLogsDetail({ modelId }) {
|
||||
)
|
||||
}
|
||||
|
||||
// DistributedBackendLogsResolver runs only in distributed mode. The local
|
||||
// /api/backend-logs WebSocket has no backend behind it here (inference lives
|
||||
// on workers), so we resolve modelId → hosting node(s) and forward to the
|
||||
// per-node logs page. One hit redirects automatically; multiple hits render
|
||||
// a picker so the operator can pick which worker's logs to inspect.
|
||||
function DistributedBackendLogsResolver({ modelId, fromTimestamp }) {
|
||||
const [hits, setHits] = useState(null) // [{ node, model }] once resolved
|
||||
const [error, setError] = useState(null)
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
;(async () => {
|
||||
try {
|
||||
const nodes = await nodesApi.list()
|
||||
const nodeList = Array.isArray(nodes) ? nodes : []
|
||||
// Fan out to each node and collect entries that match this model.
|
||||
// Per-node failures are tolerated — a single offline worker shouldn't
|
||||
// hide logs available on its peers.
|
||||
const perNode = await Promise.all(nodeList.map(async (node) => {
|
||||
try {
|
||||
const models = await nodesApi.getModels(node.id)
|
||||
const matches = (Array.isArray(models) ? models : []).filter(m => m.model_name === modelId)
|
||||
return matches.map(m => ({ node, model: m }))
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
}))
|
||||
if (cancelled) return
|
||||
setHits(perNode.flat())
|
||||
} catch (err) {
|
||||
if (!cancelled) setError(err)
|
||||
}
|
||||
})()
|
||||
return () => { cancelled = true }
|
||||
}, [modelId])
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="page page--wide">
|
||||
<div className="empty-state">
|
||||
<div className="empty-state-icon"><i className="fas fa-exclamation-triangle" /></div>
|
||||
<h2 className="empty-state-title">Failed to resolve hosting nodes</h2>
|
||||
<p className="empty-state-text">{error.message}</p>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (hits === null) {
|
||||
return (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', padding: 'var(--spacing-xl)' }}>
|
||||
<LoadingSpinner size="lg" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (hits.length === 0) {
|
||||
return (
|
||||
<div className="page page--wide">
|
||||
<div className="empty-state">
|
||||
<div className="empty-state-icon"><i className="fas fa-terminal" /></div>
|
||||
<h2 className="empty-state-title">Model not loaded on any worker</h2>
|
||||
<p className="empty-state-text">
|
||||
<span style={{ fontFamily: 'var(--font-mono)' }}>{modelId}</span> isn't currently loaded on any node in the cluster.
|
||||
Check the <Link to="/app/nodes" style={{ color: 'var(--color-primary)' }}>Nodes page</Link> to see which models are running where.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Bare model name aggregates this node's replicas via the worker's log
|
||||
// store; preserve ?from= so the deep-link from a trace still scrolls to
|
||||
// the right line on arrival.
|
||||
const buildHref = (nodeId) => {
|
||||
const base = `/app/node-backend-logs/${nodeId}/${encodeURIComponent(modelId)}`
|
||||
return fromTimestamp ? `${base}?from=${encodeURIComponent(fromTimestamp)}` : base
|
||||
}
|
||||
|
||||
if (hits.length === 1) {
|
||||
return <Navigate to={buildHref(hits[0].node.id)} replace />
|
||||
}
|
||||
|
||||
// Multiple workers host this model — let the operator pick.
|
||||
return (
|
||||
<div className="page page--wide">
|
||||
<div className="page-header">
|
||||
<div>
|
||||
<h1 className="page-title" style={{ marginBottom: 0 }}>
|
||||
<i className="fas fa-terminal" style={{ fontSize: '0.8em', marginRight: 'var(--spacing-sm)' }} />
|
||||
{modelId}
|
||||
</h1>
|
||||
<p className="page-subtitle" style={{ marginTop: 'var(--spacing-xs)' }}>
|
||||
Hosted on {hits.length} workers — pick one to view its logs.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-xs)' }}>
|
||||
{hits.map(({ node, model }) => (
|
||||
<Link
|
||||
key={`${node.id}#${model.replica_index ?? 0}`}
|
||||
to={buildHref(node.id)}
|
||||
style={{
|
||||
display: 'flex', alignItems: 'center', justifyContent: 'space-between',
|
||||
padding: 'var(--spacing-sm) var(--spacing-md)',
|
||||
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-md)', textDecoration: 'none', color: 'inherit',
|
||||
}}
|
||||
>
|
||||
<div>
|
||||
<div style={{ fontWeight: 500 }}>{node.name || node.id}</div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-secondary)', fontFamily: 'var(--font-mono)' }}>
|
||||
{node.id}{model.replica_index ? ` · replica ${model.replica_index}` : ''} · {model.state}
|
||||
</div>
|
||||
</div>
|
||||
<i className="fas fa-chevron-right" style={{ color: 'var(--color-text-muted)' }} />
|
||||
</Link>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// BackendLogsRouter picks between the local WebSocket view (standalone) and
|
||||
// the distributed resolver. The probe runs once via useDistributedMode so a
|
||||
// 503 from /api/nodes (the canonical "distributed disabled" signal) keeps the
|
||||
// existing standalone path intact.
|
||||
function BackendLogsRouter({ modelId }) {
|
||||
const [searchParams] = useSearchParams()
|
||||
const fromTimestamp = searchParams.get('from')
|
||||
const { enabled: distributedMode, loading } = useDistributedMode()
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', padding: 'var(--spacing-xl)' }}>
|
||||
<LoadingSpinner size="lg" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (distributedMode) {
|
||||
return <DistributedBackendLogsResolver modelId={modelId} fromTimestamp={fromTimestamp} />
|
||||
}
|
||||
|
||||
return <BackendLogsDetail modelId={modelId} />
|
||||
}
|
||||
|
||||
export default function BackendLogs() {
|
||||
const { modelId } = useParams()
|
||||
|
||||
if (modelId) {
|
||||
return <BackendLogsDetail modelId={decodeURIComponent(modelId)} />
|
||||
return <BackendLogsRouter modelId={decodeURIComponent(modelId)} />
|
||||
}
|
||||
|
||||
// No model specified — redirect to System page
|
||||
|
||||
@@ -660,8 +660,7 @@ export default function Manage() {
|
||||
{ key: 'edit', icon: 'fa-pen-to-square', label: 'Edit configuration',
|
||||
onClick: () => navigate(`/app/model-editor/${encodeURIComponent(model.id)}`) },
|
||||
{ key: 'logs', icon: 'fa-terminal', label: 'Backend logs',
|
||||
onClick: () => navigate(`/app/backend-logs/${encodeURIComponent(model.id)}`),
|
||||
hidden: distributedMode },
|
||||
onClick: () => navigate(`/app/backend-logs/${encodeURIComponent(model.id)}`) },
|
||||
{ divider: true },
|
||||
{ key: 'delete', icon: 'fa-trash', label: 'Delete model', danger: true,
|
||||
onClick: () => handleDeleteModel(model.id) },
|
||||
|
||||
@@ -435,6 +435,9 @@ export default function Settings() {
|
||||
<SettingRow label="Max Items" description="Maximum number of trace items to retain (0 = unlimited)">
|
||||
<input className="input" type="number" style={{ width: 120 }} value={settings.tracing_max_items ?? ''} onChange={(e) => update('tracing_max_items', parseInt(e.target.value) || 0)} placeholder="100" disabled={!settings.enable_tracing} />
|
||||
</SettingRow>
|
||||
<SettingRow label="Max Body Bytes" description="Per-field cap (bytes) for captured request/response bodies and backend trace Data fields. Prevents large LLM histories or TTS audio snippets from locking the Traces UI. 0 = uncapped.">
|
||||
<input className="input" type="number" style={{ width: 120 }} value={settings.tracing_max_body_bytes ?? ''} onChange={(e) => update('tracing_max_body_bytes', parseInt(e.target.value) || 0)} placeholder="65536" disabled={!settings.enable_tracing} />
|
||||
</SettingRow>
|
||||
<SettingRow label="Enable Backend Logging" description="Capture backend process output per model (without requiring debug mode)">
|
||||
<Toggle checked={settings.enable_backend_logging} onChange={(v) => update('enable_backend_logging', v)} />
|
||||
</SettingRow>
|
||||
|
||||
@@ -220,7 +220,10 @@ function BackendTraceDetail({ trace }) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Backend logs link */}
|
||||
{/* Backend logs link — /app/backend-logs/:modelId is the unified entry
|
||||
point: in standalone mode it streams local logs, in distributed mode
|
||||
it resolves the model to the host worker(s) and either redirects to
|
||||
/app/node-backend-logs/<nodeId>/<modelId> or shows a node picker. */}
|
||||
{trace.model_name && (
|
||||
<div style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<a
|
||||
@@ -406,7 +409,15 @@ export default function Traces() {
|
||||
<button className="btn btn-secondary btn-sm" onClick={fetchTraces}><i className="fas fa-rotate" /> Refresh</button>
|
||||
<button className="btn btn-secondary btn-sm" onClick={handleExport} disabled={traces.length === 0}><i className="fas fa-download" /> Export</button>
|
||||
<div style={{ flex: 1 }} />
|
||||
<button className="btn btn-danger btn-sm" onClick={handleClear} disabled={traces.length === 0}><i className="fas fa-trash" /> Clear</button>
|
||||
<button
|
||||
className="btn btn-danger btn-sm"
|
||||
onClick={handleClear}
|
||||
/* Stay enabled while loading: a massive in-memory trace buffer is
|
||||
precisely the case where the user can't see the table yet and
|
||||
needs Clear to recover. Clearing an already-empty server-side
|
||||
buffer is a harmless no-op. */
|
||||
disabled={!loading && traces.length === 0}
|
||||
><i className="fas fa-trash" /> Clear</button>
|
||||
</div>
|
||||
|
||||
{settings && (() => {
|
||||
@@ -459,6 +470,17 @@ export default function Traces() {
|
||||
disabled={!settings.enable_tracing}
|
||||
/>
|
||||
</SettingRow>
|
||||
<SettingRow label="Max Body Bytes" description="Per-field cap for captured bodies and backend trace Data (0 = uncapped). Prevents oversized LLM histories or TTS snippets from locking this page in loading.">
|
||||
<input
|
||||
className="input"
|
||||
type="number"
|
||||
style={{ width: 120 }}
|
||||
value={settings.tracing_max_body_bytes ?? ''}
|
||||
onChange={(e) => setSettings(prev => ({ ...prev, tracing_max_body_bytes: parseInt(e.target.value) || 0 }))}
|
||||
placeholder="65536"
|
||||
disabled={!settings.enable_tracing}
|
||||
/>
|
||||
</SettingRow>
|
||||
<SettingRow label="Enable Backend Logging" description="Capture backend process output per model (without requiring debug mode)">
|
||||
<Toggle
|
||||
checked={settings.enable_backend_logging}
|
||||
|
||||
@@ -35,6 +35,7 @@ func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
)
|
||||
|
||||
messagesMiddleware := []echo.MiddlewareFunc{
|
||||
middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()),
|
||||
middleware.UsageMiddleware(application.AuthDB()),
|
||||
middleware.TraceMiddleware(application),
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
|
||||
@@ -18,6 +18,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
|
||||
|
||||
// Chat endpoint: POST /api/chat
|
||||
chatHandler := ollama.ChatEndpoint(
|
||||
@@ -27,6 +28,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -43,6 +45,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
generateMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -58,6 +61,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
embedMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
@@ -17,6 +17,12 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// openAI compatible API endpoint
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
// X-LocalAI-Node attribution middleware: wraps the response writer and
|
||||
// stamps the header on first write when --expose-node-header is on. No-op
|
||||
// otherwise. Applied to every inference path that routes through
|
||||
// ml.Load (chat, completion, embeddings) so distributed-mode operators
|
||||
// can observe which worker served each request.
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
|
||||
|
||||
// realtime
|
||||
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
|
||||
@@ -34,6 +40,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// chat
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant())
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -73,6 +80,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// completion
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
completionMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
|
||||
@@ -94,6 +102,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// embeddings
|
||||
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
embeddingMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -57,7 +58,6 @@ var usecaseFilters = map[string]config.ModelConfigUsecase{
|
||||
config.UsecaseRealtimeAudio: config.FLAG_REALTIME_AUDIO,
|
||||
}
|
||||
|
||||
|
||||
// extractHFRepo tries to find a HuggingFace repo ID from model overrides or URLs.
|
||||
func extractHFRepo(overrides map[string]any, urls []string) string {
|
||||
if overrides != nil {
|
||||
@@ -257,6 +257,44 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
if status != nil && status.Error != nil {
|
||||
opData["error"] = status.Error.Error()
|
||||
}
|
||||
// Expose the per-node breakdown when the Phase 4 progress sink
|
||||
// has populated OpStatus.Nodes (distributed backend installs).
|
||||
// We sort by node_name for stable UI rendering across polls;
|
||||
// the underlying slice is order-dependent on UpdateNodeProgress
|
||||
// arrival order, which the UI must not depend on. Single-node
|
||||
// ops and model installs leave Nodes empty so this block emits
|
||||
// no key, preserving the legacy payload shape.
|
||||
if status != nil && len(status.Nodes) > 0 {
|
||||
nodes := make([]map[string]any, 0, len(status.Nodes))
|
||||
for _, n := range status.Nodes {
|
||||
entry := map[string]any{
|
||||
"node_id": n.NodeID,
|
||||
"node_name": n.NodeName,
|
||||
"status": n.Status,
|
||||
"percentage": n.Percentage,
|
||||
}
|
||||
if n.FileName != "" {
|
||||
entry["file_name"] = n.FileName
|
||||
}
|
||||
if n.Current != "" {
|
||||
entry["current"] = n.Current
|
||||
}
|
||||
if n.Total != "" {
|
||||
entry["total"] = n.Total
|
||||
}
|
||||
if n.Phase != "" {
|
||||
entry["phase"] = n.Phase
|
||||
}
|
||||
if n.Error != "" {
|
||||
entry["error"] = n.Error
|
||||
}
|
||||
nodes = append(nodes, entry)
|
||||
}
|
||||
sort.SliceStable(nodes, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", nodes[i]["node_name"]) < fmt.Sprintf("%v", nodes[j]["node_name"])
|
||||
})
|
||||
opData["nodes"] = nodes
|
||||
}
|
||||
operations = append(operations, opData)
|
||||
}
|
||||
|
||||
@@ -557,11 +595,11 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
NodeStatus string `json:"node_status"`
|
||||
}
|
||||
type modelCapability struct {
|
||||
ID string `json:"id"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
Backend string `json:"backend"`
|
||||
Disabled bool `json:"disabled"`
|
||||
Pinned bool `json:"pinned"`
|
||||
ID string `json:"id"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
Backend string `json:"backend"`
|
||||
Disabled bool `json:"disabled"`
|
||||
Pinned bool `json:"pinned"`
|
||||
// LoadedOn is populated only when the node registry is active
|
||||
// (distributed mode). Lets the UI show "loaded on worker-1" without
|
||||
// the operator having to expand every node manually. An empty slice
|
||||
@@ -1159,17 +1197,17 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
}
|
||||
|
||||
return c.JSON(200, map[string]any{
|
||||
"backends": backendsJSON,
|
||||
"repositories": appConfig.BackendGalleries,
|
||||
"allTags": tags,
|
||||
"processingBackends": processingBackendsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableBackends": totalBackends,
|
||||
"installedBackends": installedBackendsCount,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
"nextPage": nextPage,
|
||||
"backends": backendsJSON,
|
||||
"repositories": appConfig.BackendGalleries,
|
||||
"allTags": tags,
|
||||
"processingBackends": processingBackendsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableBackends": totalBackends,
|
||||
"installedBackends": installedBackendsCount,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
"nextPage": nextPage,
|
||||
"systemCapability": detectedCapability,
|
||||
"preferDevelopmentBackends": appConfig.PreferDevelopmentBackends,
|
||||
})
|
||||
@@ -1599,4 +1637,3 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
app.DELETE("/api/branding/asset/:kind", localai.DeleteBrandingAssetEndpoint(appConfig), adminMiddleware)
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,63 @@ var _ = Describe("/api/operations with node-scoped backend ops", func() {
|
||||
Expect(found["isBackend"]).To(Equal(true))
|
||||
})
|
||||
|
||||
It("surfaces per-node OpStatus entries on /api/operations", func() {
|
||||
appCfg := &config.ApplicationConfig{}
|
||||
galleryService := galleryop.NewGalleryService(appCfg, nil)
|
||||
opcache := galleryop.NewOpCache(galleryService)
|
||||
|
||||
jobID := "test-op-nodes-1"
|
||||
// Register a backend op so the handler treats this as a backend
|
||||
// install (no need to consult the gallery during the test).
|
||||
opcache.SetBackend("vllm", jobID)
|
||||
|
||||
// Populate per-node entries via the P4.2 helper. The helper also
|
||||
// allocates an OpStatus under jobID, which the handler will read.
|
||||
galleryService.UpdateNodeProgress(jobID, "node-b", galleryop.NodeProgress{
|
||||
NodeID: "node-b", NodeName: "worker-b", Status: galleryop.NodeStatusRunningOnWorker,
|
||||
})
|
||||
galleryService.UpdateNodeProgress(jobID, "node-a", galleryop.NodeProgress{
|
||||
NodeID: "node-a", NodeName: "worker-a", Status: galleryop.NodeStatusDownloading, Percentage: 30, FileName: "vllm.tar",
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
routes.RegisterUIAPIRoutes(e, nil, nil, appCfg, galleryService, opcache, &application.Application{}, noopMw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/operations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var envelope struct {
|
||||
Operations []map[string]any `json:"operations"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &envelope)).To(Succeed())
|
||||
|
||||
var found map[string]any
|
||||
for _, op := range envelope.Operations {
|
||||
if op["jobID"] == jobID {
|
||||
found = op
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).ToNot(BeNil(), "operation should appear in /api/operations")
|
||||
nodes, ok := found["nodes"].([]any)
|
||||
Expect(ok).To(BeTrue(), "operation should have a nodes array")
|
||||
Expect(nodes).To(HaveLen(2))
|
||||
|
||||
// Stable sort by node_name: "worker-a" comes before "worker-b"
|
||||
// even though UpdateNodeProgress was called in reverse order.
|
||||
first := nodes[0].(map[string]any)
|
||||
Expect(first["node_name"]).To(Equal("worker-a"))
|
||||
Expect(first["status"]).To(Equal("downloading"))
|
||||
Expect(first["file_name"]).To(Equal("vllm.tar"))
|
||||
Expect(first["percentage"]).To(Equal(30.0))
|
||||
|
||||
second := nodes[1].(map[string]any)
|
||||
Expect(second["node_name"]).To(Equal("worker-b"))
|
||||
Expect(second["status"]).To(Equal("running_on_worker"))
|
||||
})
|
||||
|
||||
It("does not emit nodeID for non-node-scoped backend ops", func() {
|
||||
appCfg := &config.ApplicationConfig{}
|
||||
galleryService := galleryop.NewGalleryService(appCfg, nil)
|
||||
|
||||
@@ -91,6 +91,21 @@ func (g *GalleryService) backendHandler(op *ManagementOp[gallery.GalleryBackend,
|
||||
})
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, ErrWorkerStillInstalling) {
|
||||
// Soft failure: at least one worker timed out replying but is
|
||||
// still running the install in the background. Mark the op as
|
||||
// processed with a non-error message so the admin UI shows a
|
||||
// yellow in-progress state rather than red. The reconciler's
|
||||
// next pass will reconcile the actual outcome via backend.list.
|
||||
xlog.Info("worker still installing in background", "backend", op.GalleryElementName, "error", err)
|
||||
g.UpdateStatus(op.ID, &OpStatus{
|
||||
Processed: true,
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
Message: fmt.Sprintf("backend %s: worker still installing in background; reconciler will confirm completion (%v)", op.GalleryElementName, err),
|
||||
Cancellable: false,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
xlog.Error("error installing backend", "error", err, "backend", op.GalleryElementName)
|
||||
if !op.Delete {
|
||||
// If we didn't install the backend, we need to make sure we don't have a leftover directory
|
||||
|
||||
13
core/services/galleryop/errors.go
Normal file
13
core/services/galleryop/errors.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package galleryop
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrWorkerStillInstalling indicates a distributed backend install
|
||||
// timed out at the NATS round-trip layer but the worker is most likely
|
||||
// still pulling the OCI image in the background. Producers
|
||||
// (DistributedBackendManager) wrap this when the round-trip times out;
|
||||
// consumers (backendHandler) use errors.Is(err, ErrWorkerStillInstalling)
|
||||
// to surface a yellow "in progress" OpStatus instead of a red error,
|
||||
// leaving the pending_backend_ops row in place for the reconciler to
|
||||
// confirm via backend.list.
|
||||
var ErrWorkerStillInstalling = errors.New("worker did not reply in time; install may still be running in the background")
|
||||
149
core/services/galleryop/node_progress_test.go
Normal file
149
core/services/galleryop/node_progress_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package galleryop_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
)
|
||||
|
||||
var _ = Describe("NodeStatus constants", func() {
|
||||
// Pin the wire-format string values. A future refactor that renames
|
||||
// a constant must NOT silently change the JSON value the UI receives
|
||||
// (or the cross-package contract with the nodes package, which
|
||||
// reuses these constants for NodeOpStatus.Status).
|
||||
DescribeTable("status constant",
|
||||
func(actual, expected string) {
|
||||
Expect(actual).To(Equal(expected))
|
||||
},
|
||||
Entry("queued", galleryop.NodeStatusQueued, "queued"),
|
||||
Entry("downloading", galleryop.NodeStatusDownloading, "downloading"),
|
||||
Entry("running on worker", galleryop.NodeStatusRunningOnWorker, "running_on_worker"),
|
||||
Entry("success", galleryop.NodeStatusSuccess, "success"),
|
||||
Entry("error", galleryop.NodeStatusError, "error"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("OpStatus.Nodes", func() {
|
||||
It("defaults to empty on a fresh OpStatus", func() {
|
||||
os := &galleryop.OpStatus{}
|
||||
Expect(os.Nodes).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("JSON round-trips with all NodeProgress fields", func() {
|
||||
os := &galleryop.OpStatus{
|
||||
Nodes: []galleryop.NodeProgress{
|
||||
{
|
||||
NodeID: "node-1",
|
||||
NodeName: "worker-a",
|
||||
Status: galleryop.NodeStatusRunningOnWorker,
|
||||
FileName: "vllm.tar.zst",
|
||||
Current: "412 MB",
|
||||
Total: "2.1 GB",
|
||||
Percentage: 19.6,
|
||||
Phase: "downloading", // literal pins the wire-format value
|
||||
Error: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(os)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
got := &galleryop.OpStatus{}
|
||||
Expect(json.Unmarshal(raw, got)).To(Succeed())
|
||||
Expect(got.Nodes).To(HaveLen(1))
|
||||
Expect(got.Nodes[0]).To(Equal(os.Nodes[0]))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("GalleryService.UpdateNodeProgress", func() {
|
||||
var svc *galleryop.GalleryService
|
||||
|
||||
BeforeEach(func() {
|
||||
// UpdateNodeProgress + GetStatus only touch the in-memory statuses
|
||||
// map. A zero-value ApplicationConfig is enough to get past the
|
||||
// LocalModelManager / LocalBackendManager constructors.
|
||||
svc = galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
|
||||
})
|
||||
|
||||
It("creates a node entry on first call", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{
|
||||
NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusDownloading, Percentage: 12.0,
|
||||
})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st).ToNot(BeNil())
|
||||
Expect(st.Nodes).To(HaveLen(1))
|
||||
Expect(st.Nodes[0].NodeID).To(Equal("n1"))
|
||||
Expect(st.Nodes[0].Percentage).To(Equal(12.0))
|
||||
})
|
||||
|
||||
It("merges subsequent updates into the same NodeID entry, not appending", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusDownloading, Percentage: 12.0})
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusDownloading, Percentage: 48.0, FileName: "vllm.tar"})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Nodes).To(HaveLen(1))
|
||||
Expect(st.Nodes[0].Percentage).To(Equal(48.0))
|
||||
Expect(st.Nodes[0].FileName).To(Equal("vllm.tar"))
|
||||
})
|
||||
|
||||
It("appends a new entry for a different NodeID", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusDownloading, Percentage: 12.0})
|
||||
svc.UpdateNodeProgress("op1", "n2", galleryop.NodeProgress{NodeID: "n2", NodeName: "worker-b", Status: galleryop.NodeStatusQueued})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Nodes).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("mirrors the latest tick into the aggregate OpStatus fields", func() {
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{
|
||||
NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusDownloading,
|
||||
Percentage: 33.0, FileName: "vllm.tar", Current: "330 MB", Total: "1 GB",
|
||||
})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Progress).To(Equal(33.0))
|
||||
Expect(st.FileName).To(Equal("vllm.tar"))
|
||||
Expect(st.DownloadedFileSize).To(Equal("330 MB"))
|
||||
Expect(st.TotalFileSize).To(Equal("1 GB"))
|
||||
})
|
||||
|
||||
It("preserves accumulated Nodes when a subsequent UpdateStatus comes through the legacy path", func() {
|
||||
// Regression: the Phase 2 progress bridge also calls the legacy
|
||||
// progressCb -> UpdateStatus(opID, &OpStatus{...}) on every tick.
|
||||
// Without preservation that overwrite would wipe the Nodes slice
|
||||
// and the UI would flicker between one node and another on a
|
||||
// multi-worker install. UpdateStatus must carry forward existing
|
||||
// Nodes when the incoming op has none.
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusSuccess})
|
||||
svc.UpdateNodeProgress("op1", "n2", galleryop.NodeProgress{NodeID: "n2", NodeName: "worker-b", Status: galleryop.NodeStatusDownloading, Percentage: 30.0})
|
||||
|
||||
// Now simulate the legacy progressCb path: a fresh OpStatus
|
||||
// pointer with no Nodes set, carrying only aggregate fields.
|
||||
svc.UpdateStatus("op1", &galleryop.OpStatus{
|
||||
Progress: 30.0,
|
||||
Message: "downloading",
|
||||
})
|
||||
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Nodes).To(HaveLen(2), "Nodes accumulated before the legacy UpdateStatus must be preserved")
|
||||
ids := []string{st.Nodes[0].NodeID, st.Nodes[1].NodeID}
|
||||
Expect(ids).To(ConsistOf("n1", "n2"))
|
||||
})
|
||||
|
||||
It("allows an explicit empty-then-populated Nodes transition to win when caller sets Nodes", func() {
|
||||
// If a caller explicitly passes a non-empty Nodes slice on the
|
||||
// incoming op, that should replace the existing slice (no merge).
|
||||
// Only an EMPTY incoming slice triggers the carry-forward.
|
||||
svc.UpdateNodeProgress("op1", "n1", galleryop.NodeProgress{NodeID: "n1", NodeName: "worker-a", Status: galleryop.NodeStatusSuccess})
|
||||
svc.UpdateStatus("op1", &galleryop.OpStatus{
|
||||
Progress: 100.0,
|
||||
Nodes: []galleryop.NodeProgress{
|
||||
{NodeID: "n9", NodeName: "worker-final", Status: galleryop.NodeStatusSuccess},
|
||||
},
|
||||
})
|
||||
st := svc.GetStatus("op1")
|
||||
Expect(st.Nodes).To(HaveLen(1))
|
||||
Expect(st.Nodes[0].NodeID).To(Equal("n9"))
|
||||
})
|
||||
})
|
||||
@@ -53,6 +53,45 @@ type OpStatus struct {
|
||||
GalleryElementName string `json:"gallery_element_name"`
|
||||
Cancelled bool `json:"cancelled"` // Cancelled is true if the operation was cancelled
|
||||
Cancellable bool `json:"cancellable"` // Cancellable is true if the operation can be cancelled
|
||||
|
||||
// Nodes is the per-node breakdown for a fanned-out backend install.
|
||||
// Populated by DistributedBackendManager (per-node terminal status)
|
||||
// and by the Phase 2 progress bridge (per-byte ticks). The
|
||||
// /api/operations handler surfaces this so the UI can render an
|
||||
// expandable per-node view of an in-flight install.
|
||||
Nodes []NodeProgress `json:"nodes,omitempty"`
|
||||
}
|
||||
|
||||
// NodeStatus values shared between NodeProgress (per-node tick) and the
|
||||
// NodeOpStatus surfaced by DistributedBackendManager's fan-out. Defined
|
||||
// as exported constants so producers (the manager, the progress bridge)
|
||||
// and consumers (the /api/operations handler, the React OperationsBar
|
||||
// through its JSON contract) stay in sync via a single source of truth.
|
||||
const (
|
||||
NodeStatusQueued = "queued" // node accepted the intent but install has not started
|
||||
NodeStatusDownloading = "downloading" // worker is actively pulling the OCI image
|
||||
NodeStatusRunningOnWorker = "running_on_worker" // NATS round-trip timed out but worker is still installing
|
||||
NodeStatusSuccess = "success" // install completed on this node
|
||||
NodeStatusError = "error" // install failed on this node
|
||||
)
|
||||
|
||||
// NodeProgress is a single node's contribution to a backend install
|
||||
// operation. Populated by DistributedBackendManager (per-node terminal
|
||||
// status) and by the Phase 2 progress bridge (per-byte ticks). Read by
|
||||
// the /api/operations handler so the UI can render an expandable
|
||||
// per-node breakdown.
|
||||
//
|
||||
// Status holds one of the NodeStatus* constants above.
|
||||
type NodeProgress struct {
|
||||
NodeID string `json:"node_id"`
|
||||
NodeName string `json:"node_name"`
|
||||
Status string `json:"status"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
Current string `json:"current,omitempty"`
|
||||
Total string `json:"total,omitempty"`
|
||||
Percentage float64 `json:"percentage"`
|
||||
Phase string `json:"phase,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type OpCache struct {
|
||||
|
||||
@@ -110,6 +110,18 @@ func (g *GalleryService) DeleteBackend(name string) error {
|
||||
func (g *GalleryService) UpdateStatus(s string, op *OpStatus) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
// Preserve any per-node entries already accumulated by UpdateNodeProgress:
|
||||
// the legacy progressCb path (used by the Phase 2 install bridge) calls
|
||||
// UpdateStatus with a fresh *OpStatus on every tick, which would otherwise
|
||||
// wipe the Nodes slice and leave the UI flickering between one node and
|
||||
// another. If the caller explicitly populates Nodes on the incoming op,
|
||||
// that wins; an empty Nodes slice on the incoming op is treated as "no
|
||||
// new per-node data" and the previous Nodes are carried forward.
|
||||
if op != nil && len(op.Nodes) == 0 {
|
||||
if prev := g.statuses[s]; prev != nil && len(prev.Nodes) > 0 {
|
||||
op.Nodes = prev.Nodes
|
||||
}
|
||||
}
|
||||
g.statuses[s] = op
|
||||
|
||||
// Persist to PostgreSQL in distributed mode
|
||||
@@ -135,6 +147,47 @@ func (g *GalleryService) UpdateStatus(s string, op *OpStatus) {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateNodeProgress merges a per-node progress tick into OpStatus.Nodes,
|
||||
// keyed by nodeID, and mirrors the latest values into the aggregate
|
||||
// Progress / FileName / DownloadedFileSize / TotalFileSize / Message
|
||||
// fields so the legacy single-bar OperationsBar view keeps working
|
||||
// unchanged alongside the new per-node breakdown.
|
||||
//
|
||||
// We deliberately do NOT delegate the aggregate mirror to UpdateStatus
|
||||
// here: UpdateStatus overwrites the entire OpStatus, which would clobber
|
||||
// the Nodes slice we just merged into. Doing the merge + mirror under a
|
||||
// single lock keeps both views consistent and concurrent-safe.
|
||||
func (g *GalleryService) UpdateNodeProgress(opID, nodeID string, np NodeProgress) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
status := g.statuses[opID]
|
||||
if status == nil {
|
||||
status = &OpStatus{}
|
||||
g.statuses[opID] = status
|
||||
}
|
||||
merged := false
|
||||
for i := range status.Nodes {
|
||||
if status.Nodes[i].NodeID == nodeID {
|
||||
status.Nodes[i] = np
|
||||
merged = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !merged {
|
||||
status.Nodes = append(status.Nodes, np)
|
||||
}
|
||||
|
||||
// Mirror the latest tick into the legacy aggregate fields so the
|
||||
// existing single-bar UI keeps rendering meaningful progress.
|
||||
status.FileName = np.FileName
|
||||
status.Progress = np.Percentage
|
||||
status.DownloadedFileSize = np.Current
|
||||
status.TotalFileSize = np.Total
|
||||
if np.Phase != "" {
|
||||
status.Message = np.Phase
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GalleryService) GetStatus(s string) *OpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
36
core/services/messaging/backend_install_progress.go
Normal file
36
core/services/messaging/backend_install_progress.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package messaging
|
||||
|
||||
// Phase values published on the BackendInstallProgressEvent.Phase field.
|
||||
// Defined as exported constants so producer (worker install handler) and
|
||||
// consumer (master bridge into OpStatus) share a single source of truth
|
||||
// instead of two copies of the literal string.
|
||||
const (
|
||||
PhaseResolving = "resolving" // worker is locating the gallery / image manifest
|
||||
PhaseDownloading = "downloading" // worker is actively pulling layers
|
||||
PhaseExtracting = "extracting" // worker is unpacking the downloaded archive
|
||||
PhaseStarting = "starting" // worker is spawning the gRPC backend process
|
||||
)
|
||||
|
||||
// BackendInstallProgressEvent is the wire payload published by a worker to
|
||||
// nodes.<nodeID>.backend.install.<opID>.progress while a long-running install
|
||||
// is in flight. Transient: dropped events are acceptable, the master relies
|
||||
// on BackendInstallReply for ground truth on success/failure.
|
||||
//
|
||||
// Phase holds one of the Phase* constants above.
|
||||
type BackendInstallProgressEvent struct {
|
||||
OpID string `json:"op_id"`
|
||||
NodeID string `json:"node_id"`
|
||||
Backend string `json:"backend"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
Current string `json:"current,omitempty"` // human-readable size, e.g. "412 MB"
|
||||
Total string `json:"total,omitempty"` // human-readable size, e.g. "2.1 GB"
|
||||
Percentage float64 `json:"percentage"`
|
||||
Phase string `json:"phase,omitempty"`
|
||||
}
|
||||
|
||||
// SubjectNodeBackendInstallProgress returns the NATS subject for transient
|
||||
// progress events emitted by a worker during a single backend.install run.
|
||||
// Per-op so multiple concurrent installs on the same node never alias.
|
||||
func SubjectNodeBackendInstallProgress(nodeID, opID string) string {
|
||||
return subjectNodePrefix + sanitizeSubjectToken(nodeID) + ".backend.install." + sanitizeSubjectToken(opID) + ".progress"
|
||||
}
|
||||
66
core/services/messaging/backend_install_progress_test.go
Normal file
66
core/services/messaging/backend_install_progress_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
var _ = Describe("Phase constants", func() {
|
||||
// Pin the wire-format string values. A future refactor that renames
|
||||
// a constant must NOT silently change the JSON value the master
|
||||
// receives or break consumers that switch on Phase.
|
||||
DescribeTable("phase constant",
|
||||
func(actual, expected string) {
|
||||
Expect(actual).To(Equal(expected))
|
||||
},
|
||||
Entry("resolving", messaging.PhaseResolving, "resolving"),
|
||||
Entry("downloading", messaging.PhaseDownloading, "downloading"),
|
||||
Entry("extracting", messaging.PhaseExtracting, "extracting"),
|
||||
Entry("starting", messaging.PhaseStarting, "starting"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("BackendInstallProgress", func() {
|
||||
Context("SubjectNodeBackendInstallProgress", func() {
|
||||
It("composes the per-op progress subject", func() {
|
||||
Expect(messaging.SubjectNodeBackendInstallProgress("node-abc", "op-123")).
|
||||
To(Equal("nodes.node-abc.backend.install.op-123.progress"))
|
||||
})
|
||||
|
||||
It("sanitizes NATS-reserved characters in node and op tokens", func() {
|
||||
// '.' is the NATS hierarchy delimiter, '*' and '>' are wildcards,
|
||||
// and whitespace must be stripped - sanitizeSubjectToken replaces
|
||||
// all of them with '-'. The resulting subject must still parse as
|
||||
// exactly six hierarchy segments: nodes/<node>/backend/install/<op>/progress.
|
||||
subj := messaging.SubjectNodeBackendInstallProgress("a.b c", "x.y z")
|
||||
Expect(subj).ToNot(ContainSubstring(" "))
|
||||
Expect(strings.Count(subj, ".")).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BackendInstallProgressEvent", func() {
|
||||
It("JSON round-trips with all known fields", func() {
|
||||
ev := messaging.BackendInstallProgressEvent{
|
||||
OpID: "op-123",
|
||||
NodeID: "node-abc",
|
||||
Backend: "vllm",
|
||||
FileName: "vllm-cpu.tar.zst",
|
||||
Current: "412 MB",
|
||||
Total: "2.1 GB",
|
||||
Percentage: 19.6,
|
||||
Phase: "downloading",
|
||||
}
|
||||
raw, err := json.Marshal(ev)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var got messaging.BackendInstallProgressEvent
|
||||
Expect(json.Unmarshal(raw, &got)).To(Succeed())
|
||||
Expect(got).To(Equal(ev))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -144,6 +144,12 @@ type BackendInstallRequest struct {
|
||||
// worker still works (the master's install fallback path also uses this
|
||||
// when backend.upgrade returns nats.ErrNoResponders).
|
||||
Force bool `json:"force,omitempty"`
|
||||
// OpID identifies the admin-side operation. When non-empty the worker
|
||||
// publishes BackendInstallProgressEvent values to
|
||||
// SubjectNodeBackendInstallProgress(nodeID, OpID) while the install is
|
||||
// running, debounced to roughly 250ms. Empty means the caller is a
|
||||
// reconciler-driven retry that does not need progress streamed.
|
||||
OpID string `json:"op_id,omitempty"`
|
||||
}
|
||||
|
||||
// BackendInstallReply is the response from a backend.install NATS request.
|
||||
|
||||
120
core/services/nodes/install_progress_publisher.go
Normal file
120
core/services/nodes/install_progress_publisher.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
// DebouncedInstallProgressPublisher buffers backend-install download ticks
|
||||
// and publishes them to the per-op NATS progress subject at most once per
|
||||
// `interval`. Always publishes the final event on Flush so the UI sees the
|
||||
// terminal percentage.
|
||||
//
|
||||
// Behavior: leading-edge debounce. The first OnDownload after a quiet window
|
||||
// publishes immediately; subsequent ticks within `interval` only buffer the
|
||||
// latest event, which is then emitted via a single trailing timer. This
|
||||
// keeps the wire chatter bounded (~4 events per second at 250ms) while
|
||||
// still surfacing every meaningful percentage jump.
|
||||
//
|
||||
// Lock ordering: never hold p.mu across a Publish call. Publish hits the
|
||||
// NATS client which may block on a slow link, and we don't want a stalled
|
||||
// network to stall the underlying gallery download loop.
|
||||
type DebouncedInstallProgressPublisher struct {
|
||||
mu sync.Mutex
|
||||
client messaging.MessagingClient
|
||||
subject string
|
||||
nodeID string
|
||||
opID string
|
||||
backend string
|
||||
interval time.Duration
|
||||
lastPublishedAt time.Time
|
||||
pending *messaging.BackendInstallProgressEvent
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewDebouncedInstallProgressPublisher constructs a publisher for one
|
||||
// install operation. interval is the leading-edge debounce window
|
||||
// (~250ms in production).
|
||||
func NewDebouncedInstallProgressPublisher(client messaging.MessagingClient, nodeID, opID, backend string, interval time.Duration) *DebouncedInstallProgressPublisher {
|
||||
return &DebouncedInstallProgressPublisher{
|
||||
client: client,
|
||||
subject: messaging.SubjectNodeBackendInstallProgress(nodeID, opID),
|
||||
nodeID: nodeID,
|
||||
opID: opID,
|
||||
backend: backend,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// OnDownload is the callback shape gallery.InstallBackendFromGallery and
|
||||
// galleryop.InstallExternalBackend pass into the worker. Each invocation
|
||||
// represents a single tick from the underlying io.Reader copy loop.
|
||||
func (p *DebouncedInstallProgressPublisher) OnDownload(file, current, total string, percentage float64) {
|
||||
ev := messaging.BackendInstallProgressEvent{
|
||||
OpID: p.opID,
|
||||
NodeID: p.nodeID,
|
||||
Backend: p.backend,
|
||||
FileName: file,
|
||||
Current: current,
|
||||
Total: total,
|
||||
Percentage: percentage,
|
||||
Phase: messaging.PhaseDownloading,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
now := time.Now()
|
||||
if p.lastPublishedAt.IsZero() || now.Sub(p.lastPublishedAt) >= p.interval {
|
||||
// Leading edge: publish immediately.
|
||||
p.lastPublishedAt = now
|
||||
p.pending = nil
|
||||
p.mu.Unlock()
|
||||
_ = p.client.Publish(p.subject, ev)
|
||||
return
|
||||
}
|
||||
// Within the window: buffer the latest event and arm a trailing
|
||||
// publish. If a timer is already armed, we just overwrite p.pending so
|
||||
// the trailing publish carries the freshest data.
|
||||
p.pending = &ev
|
||||
if p.timer == nil {
|
||||
delay := p.interval - now.Sub(p.lastPublishedAt)
|
||||
p.timer = time.AfterFunc(delay, p.flushPending)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// flushPending is the trailing-edge publisher fired by the AfterFunc timer.
|
||||
// It clears the pending slot under the lock, then publishes outside the
|
||||
// lock so Publish never blocks an in-progress OnDownload call.
|
||||
func (p *DebouncedInstallProgressPublisher) flushPending() {
|
||||
p.mu.Lock()
|
||||
p.timer = nil
|
||||
pending := p.pending
|
||||
p.pending = nil
|
||||
if pending != nil {
|
||||
p.lastPublishedAt = time.Now()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if pending != nil {
|
||||
_ = p.client.Publish(p.subject, *pending)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush publishes any pending buffered event synchronously and stops the
|
||||
// pending timer. Safe to call multiple times. Callers MUST defer Flush
|
||||
// after constructing the publisher so the terminal percentage reaches the
|
||||
// master even on error returns.
|
||||
func (p *DebouncedInstallProgressPublisher) Flush() {
|
||||
p.mu.Lock()
|
||||
if p.timer != nil {
|
||||
p.timer.Stop()
|
||||
p.timer = nil
|
||||
}
|
||||
pending := p.pending
|
||||
p.pending = nil
|
||||
p.mu.Unlock()
|
||||
if pending != nil {
|
||||
_ = p.client.Publish(p.subject, *pending)
|
||||
}
|
||||
}
|
||||
48
core/services/nodes/install_progress_publisher_test.go
Normal file
48
core/services/nodes/install_progress_publisher_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
var _ = Describe("DebouncedInstallProgressPublisher", func() {
|
||||
It("publishes the first event immediately and debounces subsequent ones within the window", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
pub := NewDebouncedInstallProgressPublisher(mc, "n1", "op1", "vllm", 50*time.Millisecond)
|
||||
|
||||
// Three rapid-fire ticks within the debounce window.
|
||||
pub.OnDownload("vllm.tar.zst", "100 MB", "1 GB", 10.0)
|
||||
pub.OnDownload("vllm.tar.zst", "200 MB", "1 GB", 20.0)
|
||||
pub.OnDownload("vllm.tar.zst", "300 MB", "1 GB", 30.0)
|
||||
pub.Flush()
|
||||
|
||||
// First event publishes immediately; the others coalesce; Flush guarantees a final.
|
||||
// So we expect at least 2 publishes and at most 4 (lead + final + any window-bounded).
|
||||
Eventually(func() int {
|
||||
return len(mc.publishCalls(messaging.SubjectNodeBackendInstallProgress("n1", "op1")))
|
||||
}, "1s").Should(BeNumerically(">=", 2))
|
||||
calls := mc.publishCalls(messaging.SubjectNodeBackendInstallProgress("n1", "op1"))
|
||||
Expect(len(calls)).To(BeNumerically("<=", 4),
|
||||
"three ticks within the debounce window should produce at most ~4 publishes")
|
||||
})
|
||||
|
||||
It("publishes the final event after Flush with the latest percentage", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
pub := NewDebouncedInstallProgressPublisher(mc, "n1", "op1", "vllm", 50*time.Millisecond)
|
||||
|
||||
pub.OnDownload("vllm.tar.zst", "1 GB", "1 GB", 100.0)
|
||||
pub.Flush()
|
||||
|
||||
Eventually(func() float64 {
|
||||
calls := mc.publishCalls(messaging.SubjectNodeBackendInstallProgress("n1", "op1"))
|
||||
if len(calls) == 0 {
|
||||
return -1
|
||||
}
|
||||
return calls[len(calls)-1].Percentage
|
||||
}, "1s").Should(Equal(100.0))
|
||||
})
|
||||
})
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -48,6 +49,13 @@ func (d *DistributedModelManager) InstallModel(ctx context.Context, op *galleryo
|
||||
return d.local.InstallModel(ctx, op, progressCb)
|
||||
}
|
||||
|
||||
// nodeProgressSink is the narrow interface DistributedBackendManager uses to
|
||||
// publish per-node progress without dragging in the full *GalleryService.
|
||||
// nil means "no sink, skip per-node writes" (used by single-node tests).
|
||||
type nodeProgressSink interface {
|
||||
UpdateNodeProgress(opID, nodeID string, np galleryop.NodeProgress)
|
||||
}
|
||||
|
||||
// DistributedBackendManager wraps a local BackendManager and adds NATS fan-out
|
||||
// for backend deletion so worker nodes clean up stale files.
|
||||
type DistributedBackendManager struct {
|
||||
@@ -56,26 +64,31 @@ type DistributedBackendManager struct {
|
||||
registry *NodeRegistry
|
||||
backendGalleries []config.Gallery
|
||||
systemState *system.SystemState
|
||||
progressSink nodeProgressSink
|
||||
}
|
||||
|
||||
// NewDistributedBackendManager creates a DistributedBackendManager.
|
||||
func NewDistributedBackendManager(appConfig *config.ApplicationConfig, ml *model.ModelLoader, adapter *RemoteUnloaderAdapter, registry *NodeRegistry) *DistributedBackendManager {
|
||||
// progressSink may be nil to disable per-node OpStatus writes (single-node
|
||||
// tests don't need it).
|
||||
func NewDistributedBackendManager(appConfig *config.ApplicationConfig, ml *model.ModelLoader, adapter *RemoteUnloaderAdapter, registry *NodeRegistry, progressSink nodeProgressSink) *DistributedBackendManager {
|
||||
return &DistributedBackendManager{
|
||||
local: galleryop.NewLocalBackendManager(appConfig, ml),
|
||||
adapter: adapter,
|
||||
registry: registry,
|
||||
backendGalleries: appConfig.BackendGalleries,
|
||||
systemState: appConfig.SystemState,
|
||||
progressSink: progressSink,
|
||||
}
|
||||
}
|
||||
|
||||
// NodeOpStatus is the per-node outcome of a backend lifecycle operation.
|
||||
// Returned as part of BackendOpResult so the frontend can surface exactly
|
||||
// what happened on each worker instead of a single joined error string.
|
||||
// Status holds one of the galleryop.NodeStatus* constants.
|
||||
type NodeOpStatus struct {
|
||||
NodeID string `json:"node_id"`
|
||||
NodeName string `json:"node_name"`
|
||||
Status string `json:"status"` // "success" | "queued" | "error"
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
@@ -93,7 +106,7 @@ type BackendOpResult struct {
|
||||
func (r BackendOpResult) Err() error {
|
||||
var failures []string
|
||||
for _, n := range r.Nodes {
|
||||
if n.Status == "error" {
|
||||
if n.Status == galleryop.NodeStatusError {
|
||||
failures = append(failures, fmt.Sprintf("%s: %s", n.NodeName, n.Error))
|
||||
}
|
||||
}
|
||||
@@ -116,25 +129,48 @@ func (r BackendOpResult) Err() error {
|
||||
// when the node returns.
|
||||
// targetNodeIDs is an optional allowlist: when non-nil, only nodes whose ID is
|
||||
// in the set are visited. Used by UpgradeBackend to avoid asking nodes that
|
||||
// never had the backend installed to "upgrade" it — such requests fail at the
|
||||
// never had the backend installed to "upgrade" it - such requests fail at the
|
||||
// gallery (no platform variant) and would otherwise leave a forever-retrying
|
||||
// pending_backend_ops row. nil means "fan out to every node" (Install/Delete).
|
||||
func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context, op, backend string, galleriesJSON []byte, targetNodeIDs map[string]bool, apply func(node BackendNode) error) (BackendOpResult, error) {
|
||||
//
|
||||
// opID is the gallery operation identifier; when non-empty and progressSink is
|
||||
// set, every per-node terminal status appended to BackendOpResult is also
|
||||
// mirrored into the sink so the UI's per-node OpStatus.Nodes view stays in
|
||||
// lockstep with the manager's view. opID may be empty for ops that aren't
|
||||
// gallery-tracked (e.g. DeleteBackend's plain code path).
|
||||
func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context, opID, op, backend string, galleriesJSON []byte, targetNodeIDs map[string]bool, apply func(node BackendNode) error) (BackendOpResult, error) {
|
||||
allNodes, err := d.registry.List(ctx)
|
||||
if err != nil {
|
||||
return BackendOpResult{}, err
|
||||
}
|
||||
|
||||
// emitNodeProgress is a small helper that funnels every NodeOpStatus we
|
||||
// append to result.Nodes into the per-node OpStatus sink (when configured
|
||||
// and opID is known). Keeping it inline avoids drift between the
|
||||
// BackendOpResult view and the sink view - they're written from the same
|
||||
// code path on the same terminal statuses.
|
||||
emitNodeProgress := func(node BackendNode, status, errMsg string) {
|
||||
if d.progressSink == nil || opID == "" {
|
||||
return
|
||||
}
|
||||
d.progressSink.UpdateNodeProgress(opID, node.ID, galleryop.NodeProgress{
|
||||
NodeID: node.ID,
|
||||
NodeName: node.Name,
|
||||
Status: status,
|
||||
Error: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
result := BackendOpResult{Nodes: make([]NodeOpStatus, 0, len(allNodes))}
|
||||
for _, node := range allNodes {
|
||||
// Pending nodes haven't been approved yet — no intent to apply.
|
||||
// Pending nodes haven't been approved yet - no intent to apply.
|
||||
if node.Status == StatusPending {
|
||||
continue
|
||||
}
|
||||
// Backend lifecycle ops only make sense on backend-type workers.
|
||||
// Agent workers don't subscribe to backend.install/delete/list, so
|
||||
// enqueueing for them guarantees a forever-retrying row that the
|
||||
// reconciler can never drain. Silently skip — they aren't consumers.
|
||||
// reconciler can never drain. Silently skip - they aren't consumers.
|
||||
if node.NodeType != "" && node.NodeType != NodeTypeBackend {
|
||||
continue
|
||||
}
|
||||
@@ -143,19 +179,23 @@ func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context
|
||||
}
|
||||
if err := d.registry.UpsertPendingBackendOp(ctx, node.ID, backend, op, galleriesJSON); err != nil {
|
||||
xlog.Warn("Failed to enqueue backend op", "op", op, "node", node.Name, "backend", backend, "error", err)
|
||||
errMsg := fmt.Sprintf("enqueue failed: %v", err)
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "error",
|
||||
Error: fmt.Sprintf("enqueue failed: %v", err),
|
||||
NodeID: node.ID, NodeName: node.Name, Status: galleryop.NodeStatusError,
|
||||
Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, galleryop.NodeStatusError, errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
if node.Status != StatusHealthy {
|
||||
// Intent is recorded; reconciler will retry when the node recovers.
|
||||
errMsg := fmt.Sprintf("node %s, will retry when healthy", node.Status)
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "queued",
|
||||
Error: fmt.Sprintf("node %s, will retry when healthy", node.Status),
|
||||
NodeID: node.ID, NodeName: node.Name, Status: galleryop.NodeStatusQueued,
|
||||
Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, galleryop.NodeStatusQueued, errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -167,14 +207,33 @@ func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context
|
||||
xlog.Debug("Failed to clear pending backend op after success", "error", err)
|
||||
}
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "success",
|
||||
NodeID: node.ID, NodeName: node.Name, Status: galleryop.NodeStatusSuccess,
|
||||
})
|
||||
emitNodeProgress(node, galleryop.NodeStatusSuccess, "")
|
||||
continue
|
||||
}
|
||||
|
||||
// Record failure for backoff. If it's an ErrNoResponders, the node's
|
||||
// gone AWOL — mark unhealthy so the router stops picking it too.
|
||||
// gone AWOL - mark unhealthy so the router stops picking it too.
|
||||
errMsg := applyErr.Error()
|
||||
|
||||
// Worker-still-installing is a "soft" failure: the worker is most
|
||||
// likely still pulling the OCI image. Keep the row, push NextRetryAt
|
||||
// out so the reconciler does not immediately re-fire another install
|
||||
// while the worker is still busy, and report the in-progress state
|
||||
// to the caller. The next reconciler pass / backend.list confirms
|
||||
// the actual outcome.
|
||||
if errors.Is(applyErr, galleryop.ErrWorkerStillInstalling) {
|
||||
if id, err := d.findPendingRow(ctx, node.ID, backend, op); err == nil {
|
||||
_ = d.registry.RecordPendingBackendOpInFlight(ctx, id, errMsg, d.adapter.InstallTimeout())
|
||||
}
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: galleryop.NodeStatusRunningOnWorker, Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, galleryop.NodeStatusRunningOnWorker, errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
if errors.Is(applyErr, nats.ErrNoResponders) {
|
||||
xlog.Warn("No NATS responders for node, marking unhealthy", "node", node.Name, "nodeID", node.ID)
|
||||
d.registry.MarkUnhealthy(ctx, node.ID)
|
||||
@@ -183,8 +242,9 @@ func (d *DistributedBackendManager) enqueueAndDrainBackendOp(ctx context.Context
|
||||
_ = d.registry.RecordPendingBackendOpFailure(ctx, id, errMsg)
|
||||
}
|
||||
result.Nodes = append(result.Nodes, NodeOpStatus{
|
||||
NodeID: node.ID, NodeName: node.Name, Status: "error", Error: errMsg,
|
||||
NodeID: node.ID, NodeName: node.Name, Status: galleryop.NodeStatusError, Error: errMsg,
|
||||
})
|
||||
emitNodeProgress(node, galleryop.NodeStatusError, errMsg)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -226,7 +286,11 @@ func (d *DistributedBackendManager) DeleteBackend(name string) error {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
// Empty opID: plain DeleteBackend isn't gallery-tracked the same way as
|
||||
// Install/Upgrade (no progress dialog), so we skip the per-node sink
|
||||
// writes here. DeleteBackendDetailed is the HTTP path that surfaces
|
||||
// per-node results in its own response.
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, "", OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
reply, err := d.adapter.DeleteBackend(node.ID, name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -249,7 +313,7 @@ func (d *DistributedBackendManager) DeleteBackendDetailed(ctx context.Context, n
|
||||
if err := d.local.DeleteBackend(name); err != nil && !errors.Is(err, gallery.ErrBackendNotFound) {
|
||||
return BackendOpResult{}, err
|
||||
}
|
||||
return d.enqueueAndDrainBackendOp(ctx, OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
return d.enqueueAndDrainBackendOp(ctx, "", OpBackendDelete, name, nil, nil, func(node BackendNode) error {
|
||||
reply, err := d.adapter.DeleteBackend(node.ID, name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -324,9 +388,60 @@ func (d *DistributedBackendManager) ListBackends() (gallery.SystemBackends, erro
|
||||
result[b.Name] = entry
|
||||
}
|
||||
}
|
||||
|
||||
// Proactively clear pending_backend_ops install rows whose intent is now
|
||||
// satisfied: the backend is reported installed on its target node. Without
|
||||
// this, the row sits in the queue until next_retry_at expires (up to the
|
||||
// install timeout, default 15m) and the operator UI shows the install as
|
||||
// "still installing in background" for that whole window even though the
|
||||
// worker has actually been ready for minutes. We only clear install rows;
|
||||
// upgrade and delete rows have presence-based semantics that do NOT match
|
||||
// backend.list confirmation.
|
||||
d.clearSatisfiedInstallRows(context.Background(), result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// clearSatisfiedInstallRows removes pending_backend_ops install rows whose
|
||||
// (nodeID, backend) pair now appears in the cluster-wide backend listing.
|
||||
// Called by ListBackends after fan-out so the proactive clear sees every
|
||||
// node's report. Best-effort: a DB failure is logged and the row stays for
|
||||
// the reconciler to drain via its slower path.
|
||||
func (d *DistributedBackendManager) clearSatisfiedInstallRows(ctx context.Context, backends gallery.SystemBackends) {
|
||||
rows, err := d.registry.ListPendingBackendOps(ctx)
|
||||
if err != nil {
|
||||
xlog.Debug("clearSatisfiedInstallRows: failed to list pending ops", "error", err)
|
||||
return
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return
|
||||
}
|
||||
// Build a (nodeID, backend) presence set from the listing.
|
||||
present := make(map[string]map[string]bool, len(backends))
|
||||
for name, b := range backends {
|
||||
for _, ref := range b.Nodes {
|
||||
if present[ref.NodeID] == nil {
|
||||
present[ref.NodeID] = make(map[string]bool)
|
||||
}
|
||||
present[ref.NodeID][name] = true
|
||||
}
|
||||
}
|
||||
for _, row := range rows {
|
||||
if row.Op != OpBackendInstall {
|
||||
continue
|
||||
}
|
||||
if !present[row.NodeID][row.Backend] {
|
||||
continue
|
||||
}
|
||||
if err := d.registry.DeletePendingBackendOp(ctx, row.ID); err != nil {
|
||||
xlog.Debug("clearSatisfiedInstallRows: delete failed",
|
||||
"id", row.ID, "node", row.NodeID, "backend", row.Backend, "error", err)
|
||||
continue
|
||||
}
|
||||
xlog.Info("Reconciler: pending install row satisfied by backend.list",
|
||||
"node", row.NodeID, "backend", row.Backend)
|
||||
}
|
||||
}
|
||||
|
||||
// InstallBackend fans out installation through the pending-ops queue so
|
||||
// non-healthy nodes get retried when they come back instead of being silently
|
||||
// skipped. Reply success from the NATS round-trip deletes the queue row;
|
||||
@@ -345,11 +460,41 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
|
||||
targetNodeIDs = map[string]bool{op.TargetNodeID: true}
|
||||
}
|
||||
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendInstall, backendName, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, op.ID, OpBackendInstall, backendName, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
// onProgress fans each BackendInstallProgressEvent into two
|
||||
// observers: the legacy single-bar progressCb (kept so callers
|
||||
// that only consume the aggregate view keep working) and the
|
||||
// per-node sink (so OpStatus.Nodes gets a "downloading" tick
|
||||
// per file/percentage with node attribution). Defined inside the
|
||||
// loop so each node captures its own node.Name into the closure.
|
||||
onProgress := func(ev messaging.BackendInstallProgressEvent) {
|
||||
if progressCb != nil {
|
||||
progressCb(ev.FileName, ev.Current, ev.Total, ev.Percentage)
|
||||
}
|
||||
if d.progressSink != nil && op.ID != "" {
|
||||
d.progressSink.UpdateNodeProgress(op.ID, ev.NodeID, galleryop.NodeProgress{
|
||||
NodeID: ev.NodeID,
|
||||
NodeName: node.Name,
|
||||
Status: galleryop.NodeStatusDownloading,
|
||||
FileName: ev.FileName,
|
||||
Current: ev.Current,
|
||||
Total: ev.Total,
|
||||
Percentage: ev.Percentage,
|
||||
Phase: ev.Phase,
|
||||
})
|
||||
}
|
||||
}
|
||||
// nil-callback shortcut: when there is nothing to deliver to,
|
||||
// hand the adapter a nil onProgress so it skips the per-op NATS
|
||||
// subscription. Matches the pre-Phase-4 bridgeProgressCb semantics.
|
||||
var onProgressArg func(messaging.BackendInstallProgressEvent)
|
||||
if progressCb != nil || d.progressSink != nil {
|
||||
onProgressArg = onProgress
|
||||
}
|
||||
// Admin-driven backend install: not tied to a specific replica slot.
|
||||
// Pass replica 0 - the worker's processKey is "backend#0" when no
|
||||
// modelID is supplied, matching pre-PR4 behavior.
|
||||
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0)
|
||||
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0, op.ID, onProgressArg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -361,7 +506,19 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return result.Err()
|
||||
if hardErr := result.Err(); hardErr != nil {
|
||||
return hardErr
|
||||
}
|
||||
// No hard failures, but if at least one node reported running_on_worker,
|
||||
// surface a wrapped ErrWorkerStillInstalling so galleryop can render a
|
||||
// yellow in-progress state instead of green success. The reconciler
|
||||
// will confirm the actual outcome on its next pass via backend.list.
|
||||
for _, n := range result.Nodes {
|
||||
if n.Status == galleryop.NodeStatusRunningOnWorker {
|
||||
return fmt.Errorf("%w: %s", galleryop.ErrWorkerStillInstalling, summarizeRunningOnWorker(result.Nodes))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpgradeBackend uses a separate NATS subject (backend.upgrade) so the slow
|
||||
@@ -392,7 +549,11 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
|
||||
targetNodeIDs[n.NodeID] = true
|
||||
}
|
||||
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
// Empty opID: the caller (galleryop) doesn't thread an op ID into
|
||||
// UpgradeBackend today, so we can't tag per-node sink writes with the
|
||||
// right OpStatus key. Until the upgrade path takes a ManagementOp the
|
||||
// way InstallBackend does, the sink stays no-op here.
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, "", OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0)
|
||||
if err != nil {
|
||||
// Rolling-update fallback: an older worker doesn't know
|
||||
@@ -417,7 +578,18 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return result.Err()
|
||||
if hardErr := result.Err(); hardErr != nil {
|
||||
return hardErr
|
||||
}
|
||||
// Same in-progress surfacing as InstallBackend: a long-running worker
|
||||
// upgrade that timed out the NATS round-trip must not be reported as
|
||||
// green success.
|
||||
for _, n := range result.Nodes {
|
||||
if n.Status == galleryop.NodeStatusRunningOnWorker {
|
||||
return fmt.Errorf("%w: %s", galleryop.ErrWorkerStillInstalling, summarizeRunningOnWorker(result.Nodes))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDistributed reports that installs from this manager fan out across the
|
||||
@@ -443,3 +615,16 @@ func (d *DistributedBackendManager) CheckUpgrades(ctx context.Context) (map[stri
|
||||
// it used to come from the empty frontend filesystem.
|
||||
return gallery.CheckUpgradesAgainst(ctx, d.backendGalleries, d.systemState, installed)
|
||||
}
|
||||
|
||||
// summarizeRunningOnWorker builds a short human-readable summary of which
|
||||
// nodes are still installing in the background, for inclusion in the
|
||||
// wrapped ErrWorkerStillInstalling error.
|
||||
func summarizeRunningOnWorker(nodes []NodeOpStatus) string {
|
||||
var names []string
|
||||
for _, n := range nodes {
|
||||
if n.Status == galleryop.NodeStatusRunningOnWorker {
|
||||
names = append(names, n.NodeName)
|
||||
}
|
||||
}
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
@@ -22,11 +24,35 @@ import (
|
||||
// (or error). Used so each fan-out request can simulate a different worker
|
||||
// outcome without spinning up real NATS.
|
||||
type scriptedMessagingClient struct {
|
||||
mu sync.Mutex
|
||||
replies map[string][]byte
|
||||
errs map[string]error
|
||||
calls []requestCall
|
||||
matchedReplies map[string][]matchedReply
|
||||
mu sync.Mutex
|
||||
replies map[string][]byte
|
||||
errs map[string]error
|
||||
calls []requestCall
|
||||
matchedReplies map[string][]matchedReply
|
||||
publishes []progressPublishCall
|
||||
scheduledProgressPublishes []scheduledProgressPublish
|
||||
subscribes []string
|
||||
}
|
||||
|
||||
// progressPublishCall records a single Publish invocation. The progress
|
||||
// publisher tests assert on the sequence of BackendInstallProgressEvent
|
||||
// values written to a per-op subject, so we capture both subject and the
|
||||
// decoded event. Named to avoid clashing with the simpler `publishCall`
|
||||
// already defined in unloader_test.go (which stores raw JSON bytes for
|
||||
// non-progress assertions).
|
||||
type progressPublishCall struct {
|
||||
Subject string
|
||||
Event messaging.BackendInstallProgressEvent
|
||||
}
|
||||
|
||||
// scheduledProgressPublish queues a batch of BackendInstallProgressEvent
|
||||
// values to be delivered the next time Subscribe is called with the matching
|
||||
// subject. This lets master-side tests assert that the adapter installs its
|
||||
// handler BEFORE publishing the install request, by scripting events to be
|
||||
// delivered as soon as the subscription appears.
|
||||
type scheduledProgressPublish struct {
|
||||
subject string
|
||||
events []messaging.BackendInstallProgressEvent
|
||||
}
|
||||
|
||||
// matchedReply lets a test script a canned reply that only fires when the
|
||||
@@ -98,10 +124,10 @@ func (s *scriptedMessagingClient) scriptReplyMatching(subject string, pred func(
|
||||
})
|
||||
}
|
||||
|
||||
func (s *scriptedMessagingClient) Request(subject string, data []byte, _ time.Duration) ([]byte, error) {
|
||||
func (s *scriptedMessagingClient) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.calls = append(s.calls, requestCall{Subject: subject, Data: data})
|
||||
s.calls = append(s.calls, requestCall{Subject: subject, Data: data, Timeout: timeout})
|
||||
|
||||
// Predicate-matched replies take precedence over flat scriptReply.
|
||||
if matchers, ok := s.matchedReplies[subject]; ok {
|
||||
@@ -135,8 +161,88 @@ func (s *scriptedMessagingClient) Request(subject string, data []byte, _ time.Du
|
||||
return nil, &fakeNoRespondersErr{}
|
||||
}
|
||||
|
||||
func (s *scriptedMessagingClient) Publish(_ string, _ any) error { return nil }
|
||||
func (s *scriptedMessagingClient) Subscribe(_ string, _ func([]byte)) (messaging.Subscription, error) {
|
||||
// Publish records each call so progress-publisher tests can assert on the
|
||||
// stream of events written to a subject. The real messaging.Client JSON
|
||||
// encodes the payload before sending, but our publisher hands a typed
|
||||
// struct directly, so we handle both shapes.
|
||||
func (s *scriptedMessagingClient) Publish(subject string, data any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
switch ev := data.(type) {
|
||||
case messaging.BackendInstallProgressEvent:
|
||||
s.publishes = append(s.publishes, progressPublishCall{Subject: subject, Event: ev})
|
||||
case []byte:
|
||||
var e messaging.BackendInstallProgressEvent
|
||||
_ = json.Unmarshal(ev, &e)
|
||||
s.publishes = append(s.publishes, progressPublishCall{Subject: subject, Event: e})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishCalls returns every BackendInstallProgressEvent that was published
|
||||
// to `subject`, in order. Lets tests assert on debounce behavior without
|
||||
// depending on internal Publish timing.
|
||||
func (s *scriptedMessagingClient) publishCalls(subject string) []messaging.BackendInstallProgressEvent {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]messaging.BackendInstallProgressEvent, 0)
|
||||
for _, c := range s.publishes {
|
||||
if c.Subject != subject {
|
||||
continue
|
||||
}
|
||||
out = append(out, c.Event)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// scheduleProgressPublish queues a set of BackendInstallProgressEvent values
|
||||
// to be delivered on the next Subscribe call matching the per-op progress
|
||||
// subject. A short delay before delivery gives the subscriber time to install
|
||||
// its message handler before the events arrive.
|
||||
func (s *scriptedMessagingClient) scheduleProgressPublish(nodeID, opID string, events []messaging.BackendInstallProgressEvent) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.scheduledProgressPublishes = append(s.scheduledProgressPublishes, scheduledProgressPublish{
|
||||
subject: messaging.SubjectNodeBackendInstallProgress(nodeID, opID),
|
||||
events: events,
|
||||
})
|
||||
}
|
||||
|
||||
// subscribeCalls returns the subjects on which Subscribe was invoked.
|
||||
// Used to confirm the master skipped subscription when onProgress was nil.
|
||||
func (s *scriptedMessagingClient) subscribeCalls() []string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]string, len(s.subscribes))
|
||||
copy(out, s.subscribes)
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *scriptedMessagingClient) Subscribe(subject string, handler func([]byte)) (messaging.Subscription, error) {
|
||||
s.mu.Lock()
|
||||
s.subscribes = append(s.subscribes, subject)
|
||||
matched := []scheduledProgressPublish{}
|
||||
remaining := s.scheduledProgressPublishes[:0]
|
||||
for _, sp := range s.scheduledProgressPublishes {
|
||||
if sp.subject == subject {
|
||||
matched = append(matched, sp)
|
||||
} else {
|
||||
remaining = append(remaining, sp)
|
||||
}
|
||||
}
|
||||
s.scheduledProgressPublishes = remaining
|
||||
s.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
for _, sp := range matched {
|
||||
for _, ev := range sp.events {
|
||||
raw, _ := json.Marshal(ev)
|
||||
handler(raw)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return &fakeSubscription{}, nil
|
||||
}
|
||||
func (s *scriptedMessagingClient) QueueSubscribe(_ string, _ string, _ func([]byte)) (messaging.Subscription, error) {
|
||||
@@ -151,8 +257,43 @@ func (s *scriptedMessagingClient) SubscribeReply(_ string, _ func([]byte, func([
|
||||
func (s *scriptedMessagingClient) IsConnected() bool { return true }
|
||||
func (s *scriptedMessagingClient) Close() {}
|
||||
|
||||
// recordingNodeCall captures a single UpdateNodeProgress invocation so
|
||||
// per-node OpStatus tests can assert on the sequence of writes the
|
||||
// DistributedBackendManager fans out into the sink.
|
||||
type recordingNodeCall struct {
|
||||
OpID string
|
||||
NodeID string
|
||||
Progress galleryop.NodeProgress
|
||||
}
|
||||
|
||||
// recordingProgressSink is a test-only nodeProgressSink that just records
|
||||
// every call. Used by the per-node OpStatus specs below to assert the
|
||||
// manager wrote the expected terminal and downloading entries.
|
||||
type recordingProgressSink struct {
|
||||
mu sync.Mutex
|
||||
calls []recordingNodeCall
|
||||
}
|
||||
|
||||
func (r *recordingProgressSink) UpdateNodeProgress(opID, nodeID string, np galleryop.NodeProgress) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.calls = append(r.calls, recordingNodeCall{OpID: opID, NodeID: nodeID, Progress: np})
|
||||
}
|
||||
|
||||
func (r *recordingProgressSink) callsFor(opID, nodeID string) []galleryop.NodeProgress {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := []galleryop.NodeProgress{}
|
||||
for _, c := range r.calls {
|
||||
if c.OpID == opID && c.NodeID == nodeID {
|
||||
out = append(out, c.Progress)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// fakeNoRespondersErr is the unscripted-subject default. It matches
|
||||
// nats.ErrNoResponders by string only — used when a test forgets to script
|
||||
// nats.ErrNoResponders by string only - used when a test forgets to script
|
||||
// a node so the failure is loud but doesn't tickle errors.Is(...) sentinel
|
||||
// paths the test wasn't deliberately exercising. Tests that DO want the
|
||||
// real sentinel (e.g. to drive the manager's NoResponders fallback) call
|
||||
@@ -204,7 +345,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
mc = newScriptedMessagingClient()
|
||||
adapter = NewRemoteUnloaderAdapter(nil, mc)
|
||||
adapter = NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
mgr = &DistributedBackendManager{
|
||||
local: stubLocalBackendManager{},
|
||||
adapter: adapter,
|
||||
@@ -352,6 +493,263 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
Expect(mc.calls).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when InstallBackend times out on a worker", func() {
|
||||
It("returns galleryop.ErrWorkerStillInstalling and keeps the queue row with NextRetryAt pushed out", func() {
|
||||
n := registerHealthyBackend("slow", "10.0.0.1:50051")
|
||||
|
||||
// Script a NATS timeout on the install subject. The adapter
|
||||
// wraps this into galleryop.ErrWorkerStillInstalling, which
|
||||
// the manager should treat as a soft failure.
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(n.ID), nats.ErrTimeout)
|
||||
|
||||
err := mgr.InstallBackend(ctx, op("vllm"), nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue(),
|
||||
"expected wrapped ErrWorkerStillInstalling, got %v", err)
|
||||
|
||||
rows, err := registry.ListPendingBackendOps(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rows).To(HaveLen(1))
|
||||
Expect(rows[0].Backend).To(Equal("vllm"))
|
||||
// The adapter is configured with a 3m install timeout in this
|
||||
// suite (NewRemoteUnloaderAdapter above). NextRetryAt should
|
||||
// be ~now+3m; a > now+2m bound is safe-but-tight enough to
|
||||
// catch the buggy short default (30s exponential backoff).
|
||||
Expect(rows[0].NextRetryAt).To(BeTemporally(">", time.Now().Add(2*time.Minute)),
|
||||
"NextRetryAt should be pushed to ~now+installTimeout, not the short default")
|
||||
})
|
||||
})
|
||||
|
||||
Context("end-to-end: timeout then successful reconcile via backend.list", func() {
|
||||
It("surfaces the install in ListBackends after the worker finishes", func() {
|
||||
// Use the same node-registration helper the Task 5 test uses
|
||||
// so the test fixture is identical to the prior context.
|
||||
node := registerHealthyBackend("jetson", "10.0.0.2:50051")
|
||||
|
||||
// First install attempt: NATS times out. The adapter wraps
|
||||
// this as galleryop.ErrWorkerStillInstalling and the manager
|
||||
// keeps the pending_backend_ops row alive with NextRetryAt
|
||||
// pushed out (asserted in the previous context).
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(node.ID), nats.ErrTimeout)
|
||||
|
||||
err := mgr.InstallBackend(ctx, op("vllm"), nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue(),
|
||||
"expected wrapped ErrWorkerStillInstalling, got %v", err)
|
||||
|
||||
rows, listErr := registry.ListPendingBackendOps(ctx)
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
Expect(rows).To(HaveLen(1))
|
||||
|
||||
// The worker finished installing in the background. Script
|
||||
// backend.list on the same scriptedMessagingClient so the
|
||||
// manager's ListBackends fan-out reports the backend.
|
||||
mc.scriptReply(messaging.SubjectNodeBackendList(node.ID), messaging.BackendListReply{
|
||||
Backends: []messaging.NodeBackendInfo{{Name: "vllm"}},
|
||||
})
|
||||
|
||||
backends, listErr := mgr.ListBackends()
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
Expect(backends).To(HaveKey("vllm"))
|
||||
Expect(backends["vllm"].Nodes).To(HaveLen(1))
|
||||
Expect(backends["vllm"].Nodes[0].NodeID).To(Equal(node.ID))
|
||||
|
||||
// Phase 1b shipped: ListBackends proactively clears install rows
|
||||
// whose intent is now satisfied by backend.list confirmation. The
|
||||
// operator UI clears immediately instead of waiting for the next
|
||||
// reconciler tick after NextRetryAt.
|
||||
rowsAfter, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rowsAfter).To(BeEmpty(),
|
||||
"install row should clear once backend.list confirms presence on the target node")
|
||||
})
|
||||
})
|
||||
|
||||
Context("ListBackends clears confirmed install rows", func() {
|
||||
It("deletes the pending_backend_ops install row when the backend is reported installed on its target node", func() {
|
||||
node := registerHealthyBackend("worker-a", "10.0.0.5:50051")
|
||||
|
||||
// Pre-stage: simulate an admin install that timed out at the NATS
|
||||
// round-trip, leaving an install row in the queue.
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(node.ID), nats.ErrTimeout)
|
||||
err := mgr.InstallBackend(ctx, op("vllm"), nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue())
|
||||
|
||||
rows, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rows).To(HaveLen(1))
|
||||
|
||||
// Worker finishes installing in the background. backend.list now
|
||||
// confirms presence; ListBackends should proactively clear the row.
|
||||
mc.scriptReply(messaging.SubjectNodeBackendList(node.ID), messaging.BackendListReply{
|
||||
Backends: []messaging.NodeBackendInfo{{Name: "vllm"}},
|
||||
})
|
||||
|
||||
backends, listErr := mgr.ListBackends()
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
Expect(backends).To(HaveKey("vllm"))
|
||||
|
||||
rowsAfter, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rowsAfter).To(BeEmpty(),
|
||||
"ListBackends should clear install rows whose intent is now satisfied by backend.list")
|
||||
})
|
||||
|
||||
It("does NOT clear an upgrade row even if the backend is reported installed", func() {
|
||||
node := registerHealthyBackend("worker-b", "10.0.0.6:50051")
|
||||
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, node.ID, "vllm", OpBackendUpgrade, []byte("[]"))).To(Succeed())
|
||||
|
||||
mc.scriptReply(messaging.SubjectNodeBackendList(node.ID), messaging.BackendListReply{
|
||||
Backends: []messaging.NodeBackendInfo{{Name: "vllm"}},
|
||||
})
|
||||
|
||||
_, listErr := mgr.ListBackends()
|
||||
Expect(listErr).ToNot(HaveOccurred())
|
||||
|
||||
rowsAfter, _ := registry.ListPendingBackendOps(ctx)
|
||||
Expect(rowsAfter).To(HaveLen(1), "upgrade rows must not be cleared by backend.list presence")
|
||||
})
|
||||
})
|
||||
|
||||
Context("InstallBackend streams progress events to the caller's progressCb", func() {
|
||||
It("invokes progressCb once per worker-published progress event", func() {
|
||||
node := registerHealthyBackend("worker-prog", "10.0.0.7:50051")
|
||||
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID), messaging.BackendInstallReply{Success: true, Address: "10.0.0.7:50051"})
|
||||
mc.scheduleProgressPublish(node.ID, "op-prog-1", []messaging.BackendInstallProgressEvent{
|
||||
{OpID: "op-prog-1", NodeID: node.ID, Backend: "vllm", FileName: "vllm.tar", Current: "100 MB", Total: "1 GB", Percentage: 10},
|
||||
{OpID: "op-prog-1", NodeID: node.ID, Backend: "vllm", FileName: "vllm.tar", Current: "1 GB", Total: "1 GB", Percentage: 100},
|
||||
})
|
||||
|
||||
type tick struct {
|
||||
FileName, Current, Total string
|
||||
Percentage float64
|
||||
}
|
||||
var (
|
||||
pcCalls []tick
|
||||
mu sync.Mutex
|
||||
)
|
||||
progressCb := func(file, current, total string, pct float64) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
pcCalls = append(pcCalls, tick{file, current, total, pct})
|
||||
}
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-prog-1"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, progressCb)).To(Succeed())
|
||||
|
||||
Eventually(func() int {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(pcCalls)
|
||||
}, "1s").Should(Equal(2))
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// The adapter dispatches each progress event to its own goroutine
|
||||
// (see unloader.go: `go onProgress(ev)`) so two events emitted back
|
||||
// to back can land at the bridge in either order. Assert the set of
|
||||
// percentages observed contains both ticks, rather than depending
|
||||
// on goroutine scheduling for ordering.
|
||||
pcts := []float64{pcCalls[0].Percentage, pcCalls[1].Percentage}
|
||||
Expect(pcts).To(ConsistOf(10.0, 100.0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("InstallBackend tolerates silent (pre-Phase-2) workers", func() {
|
||||
It("completes successfully even when no progress events are ever published", func() {
|
||||
node := registerHealthyBackend("worker-silent", "10.0.0.8:50051")
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID), messaging.BackendInstallReply{Success: true, Address: "10.0.0.8:50051"})
|
||||
// NO scheduleProgressPublish call - silent worker.
|
||||
|
||||
var ticks int
|
||||
var mu sync.Mutex
|
||||
progressCb := func(file, current, total string, pct float64) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
ticks++
|
||||
}
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-silent-1"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, progressCb)).To(Succeed())
|
||||
|
||||
Consistently(func() int {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return ticks
|
||||
}, "200ms").Should(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("populates per-node OpStatus entries", func() {
|
||||
var sink *recordingProgressSink
|
||||
|
||||
BeforeEach(func() {
|
||||
// Reconstruct mgr with the recording sink so the new code
|
||||
// path (per-node OpStatus writes) is exercised. The default
|
||||
// mgr in the outer BeforeEach has progressSink=nil so the
|
||||
// pre-existing specs keep verifying the no-sink behavior.
|
||||
sink = &recordingProgressSink{}
|
||||
appCfg := &config.ApplicationConfig{}
|
||||
mgr = NewDistributedBackendManager(appCfg, nil, adapter, registry, sink)
|
||||
// stubLocalBackendManager mirrors the production behaviour
|
||||
// where the frontend node rarely has the backend installed
|
||||
// locally - the NATS fan-out is what these specs verify.
|
||||
mgr.local = stubLocalBackendManager{}
|
||||
})
|
||||
|
||||
It("emits a success entry for each healthy node visited", func() {
|
||||
node := registerHealthyBackend("worker-ok", "10.0.0.9:50051")
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID),
|
||||
messaging.BackendInstallReply{Success: true, Address: "10.0.0.9:50051"})
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-node-success"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, nil)).To(Succeed())
|
||||
|
||||
calls := sink.callsFor("op-node-success", node.ID)
|
||||
Expect(calls).ToNot(BeEmpty())
|
||||
Expect(calls[len(calls)-1].Status).To(Equal(galleryop.NodeStatusSuccess))
|
||||
Expect(calls[len(calls)-1].NodeName).To(Equal("worker-ok"))
|
||||
})
|
||||
|
||||
It("emits a running_on_worker entry when NATS times out", func() {
|
||||
node := registerHealthyBackend("worker-slow", "10.0.0.10:50051")
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall(node.ID), nats.ErrTimeout)
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-node-slow"
|
||||
// Soft failure: returns wrapped ErrWorkerStillInstalling.
|
||||
_ = mgr.InstallBackend(ctx, opVal, nil)
|
||||
|
||||
calls := sink.callsFor("op-node-slow", node.ID)
|
||||
Expect(calls).ToNot(BeEmpty())
|
||||
Expect(calls[len(calls)-1].Status).To(Equal(galleryop.NodeStatusRunningOnWorker))
|
||||
})
|
||||
|
||||
It("emits downloading entries from progress events", func() {
|
||||
node := registerHealthyBackend("worker-dl", "10.0.0.11:50051")
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall(node.ID),
|
||||
messaging.BackendInstallReply{Success: true})
|
||||
mc.scheduleProgressPublish(node.ID, "op-node-dl", []messaging.BackendInstallProgressEvent{
|
||||
{OpID: "op-node-dl", NodeID: node.ID, Backend: "vllm", FileName: "vllm.tar", Current: "1 GB", Total: "1 GB", Percentage: 100, Phase: messaging.PhaseDownloading},
|
||||
})
|
||||
|
||||
opVal := op("vllm")
|
||||
opVal.ID = "op-node-dl"
|
||||
Expect(mgr.InstallBackend(ctx, opVal, nil)).To(Succeed())
|
||||
|
||||
Eventually(func() bool {
|
||||
for _, np := range sink.callsFor("op-node-dl", node.ID) {
|
||||
if np.Status == galleryop.NodeStatusDownloading && np.Percentage == 100.0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, "1s").Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UpgradeBackend", func() {
|
||||
|
||||
@@ -68,6 +68,11 @@ func (a *ModelRouterAdapter) Route(ctx context.Context, backend, modelID, modelN
|
||||
// by SmartRouter. Use NewModelWithClient so the wrapper is preserved when
|
||||
// the ModelLoader returns this model on subsequent requests.
|
||||
m := model.NewModelWithClient(modelID, result.Node.Address, result.Client)
|
||||
// Stash the picked node ID so HTTP handlers can surface it via the
|
||||
// optional X-LocalAI-Node response header. Best-effort: the in-process
|
||||
// store keeps only the latest routing decision per modelID; see the
|
||||
// nodeID field comment on Model.
|
||||
m.SetNodeID(result.Node.ID)
|
||||
|
||||
xlog.Info("Model routed to remote node", "model", modelName, "node", result.Node.Name, "address", result.Node.Address)
|
||||
return m, nil
|
||||
|
||||
94
core/services/nodes/probe_cache.go
Normal file
94
core/services/nodes/probe_cache.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// probeCache memoizes recent successful gRPC HealthCheck results for
|
||||
// (nodeID, addr) tuples so SmartRouter.probeHealth doesn't pay a round-trip
|
||||
// on every inference request.
|
||||
//
|
||||
// Why this exists: with per-request routing (see pkg/model/loader.go), every
|
||||
// inference call goes through SmartRouter.Route, which probes the backend
|
||||
// before returning a client. Many gRPC backends (notably llama.cpp's server)
|
||||
// serialize HealthCheck against active Predict on a shared goroutine, so a
|
||||
// burst of new requests can stall behind a single long-running stream —
|
||||
// exactly the "queue stalling" symptom observed in distributed clusters.
|
||||
//
|
||||
// The background HealthMonitor (perModelHealthCheck) is still the cluster-wide
|
||||
// source of truth that reaps actually-dead backends within ~45s; this cache
|
||||
// only saves the per-request hot path from re-asking when nothing has changed.
|
||||
//
|
||||
// TTL matches healthCheckTTL in pkg/model/model.go so the single-process
|
||||
// IsRecentlyHealthy path and this distributed-mode path share the same
|
||||
// staleness budget.
|
||||
type probeCache struct {
|
||||
ttl time.Duration
|
||||
mu sync.Mutex
|
||||
seen map[string]time.Time // key → last successful probe
|
||||
flight singleflight.Group // coalesces concurrent probes for the same key
|
||||
}
|
||||
|
||||
// newProbeCache returns a probeCache with the given TTL. Zero TTL disables
|
||||
// caching: every call to DoOrCached invokes the probe.
|
||||
func newProbeCache(ttl time.Duration) *probeCache {
|
||||
return &probeCache{
|
||||
ttl: ttl,
|
||||
seen: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// IsFresh reports whether key was successfully probed within TTL.
|
||||
func (c *probeCache) IsFresh(key string) bool {
|
||||
if c.ttl <= 0 {
|
||||
return false
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
last, ok := c.seen[key]
|
||||
return ok && time.Since(last) < c.ttl
|
||||
}
|
||||
|
||||
// markFresh records key as successfully probed at the current time.
|
||||
func (c *probeCache) markFresh(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.seen[key] = time.Now()
|
||||
}
|
||||
|
||||
// Invalidate drops any cached freshness for key. Used after a probe failure
|
||||
// (or any other signal that the backend may not be alive) so the next call
|
||||
// will re-probe instead of trusting stale state.
|
||||
func (c *probeCache) Invalidate(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.seen, key)
|
||||
}
|
||||
|
||||
// DoOrCached returns true if key is fresh; otherwise it runs probe (coalescing
|
||||
// concurrent callers via singleflight) and caches a successful result. Failed
|
||||
// probes invalidate the cache, so a transient miss doesn't pin every
|
||||
// subsequent request to a re-probe.
|
||||
func (c *probeCache) DoOrCached(key string, probe func() bool) bool {
|
||||
if c.IsFresh(key) {
|
||||
return true
|
||||
}
|
||||
v, _, _ := c.flight.Do(key, func() (any, error) {
|
||||
// Double-check after potentially waiting: another caller in this
|
||||
// flight may have just populated the cache.
|
||||
if c.IsFresh(key) {
|
||||
return true, nil
|
||||
}
|
||||
ok := probe()
|
||||
if ok {
|
||||
c.markFresh(key)
|
||||
} else {
|
||||
c.Invalidate(key)
|
||||
}
|
||||
return ok, nil
|
||||
})
|
||||
return v.(bool)
|
||||
}
|
||||
145
core/services/nodes/probe_cache_test.go
Normal file
145
core/services/nodes/probe_cache_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("probeCache", func() {
|
||||
It("invokes the probe on a cold cache and caches success", func() {
|
||||
c := newProbeCache(time.Minute)
|
||||
var calls int32
|
||||
probe := func() bool {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return true
|
||||
}
|
||||
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
|
||||
// Cached: probe ran once.
|
||||
Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("re-probes after the TTL expires", func() {
|
||||
// 1 ms TTL means the second call is virtually guaranteed to see an
|
||||
// expired entry without flaking on scheduler jitter.
|
||||
c := newProbeCache(time.Millisecond)
|
||||
var calls int32
|
||||
probe := func() bool {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return true
|
||||
}
|
||||
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
|
||||
Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
|
||||
})
|
||||
|
||||
It("does not cache failed probes — next call re-probes", func() {
|
||||
c := newProbeCache(time.Minute)
|
||||
var calls int32
|
||||
var result atomic.Bool
|
||||
probe := func() bool {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return result.Load()
|
||||
}
|
||||
|
||||
// First probe fails — must NOT be cached.
|
||||
result.Store(false)
|
||||
Expect(c.DoOrCached("k", probe)).To(BeFalse())
|
||||
Expect(c.IsFresh("k")).To(BeFalse())
|
||||
|
||||
// Recover: second probe succeeds and is cached.
|
||||
result.Store(true)
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(c.IsFresh("k")).To(BeTrue())
|
||||
|
||||
// Third call short-circuits on the fresh entry.
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
|
||||
})
|
||||
|
||||
It("coalesces concurrent probes via singleflight", func() {
|
||||
// Models the "6 chat completions arrive simultaneously for a
|
||||
// not-yet-cached backend" scenario. Without singleflight every caller
|
||||
// would dial the backend, defeating the purpose of the cache.
|
||||
c := newProbeCache(time.Minute)
|
||||
var calls int32
|
||||
start := make(chan struct{})
|
||||
probe := func() bool {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
// Stall briefly so the test reliably has all goroutines parked
|
||||
// inside flight.Do at the same time.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return true
|
||||
}
|
||||
|
||||
const N = 8
|
||||
var wg sync.WaitGroup
|
||||
results := make([]bool, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
results[i] = c.DoOrCached("k", probe)
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)),
|
||||
"singleflight must collapse %d concurrent probes into one", N)
|
||||
for i, got := range results {
|
||||
Expect(got).To(BeTrue(), "goroutine %d saw a different result", i)
|
||||
}
|
||||
})
|
||||
|
||||
It("treats different keys independently", func() {
|
||||
c := newProbeCache(time.Minute)
|
||||
var aCalls, bCalls int32
|
||||
Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue())
|
||||
Expect(c.DoOrCached("b", func() bool { atomic.AddInt32(&bCalls, 1); return true })).To(BeTrue())
|
||||
Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue())
|
||||
|
||||
Expect(atomic.LoadInt32(&aCalls)).To(Equal(int32(1)))
|
||||
Expect(atomic.LoadInt32(&bCalls)).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("disables caching when TTL is zero", func() {
|
||||
c := newProbeCache(0)
|
||||
var calls int32
|
||||
probe := func() bool {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return true
|
||||
}
|
||||
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
|
||||
Expect(atomic.LoadInt32(&calls)).To(Equal(int32(3)))
|
||||
})
|
||||
|
||||
It("Invalidate forces the next call to re-probe", func() {
|
||||
c := newProbeCache(time.Hour)
|
||||
var calls int32
|
||||
probe := func() bool {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return true
|
||||
}
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
c.Invalidate("k")
|
||||
Expect(c.DoOrCached("k", probe)).To(BeTrue())
|
||||
Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
|
||||
})
|
||||
})
|
||||
@@ -68,9 +68,9 @@ type ModelScheduler interface {
|
||||
|
||||
// ReplicaReconcilerOptions holds configuration for creating a ReplicaReconciler.
|
||||
type ReplicaReconcilerOptions struct {
|
||||
Registry *NodeRegistry
|
||||
Registry *NodeRegistry
|
||||
Scheduler ModelScheduler
|
||||
Unloader NodeCommandSender
|
||||
Unloader NodeCommandSender
|
||||
// Adapter is the NATS sender used to retry pending backend ops. When nil,
|
||||
// the state-reconciler pending-drain pass is a no-op (single-node mode).
|
||||
Adapter *RemoteUnloaderAdapter
|
||||
@@ -78,7 +78,7 @@ type ReplicaReconcilerOptions struct {
|
||||
// addresses. Matches the worker's token so HealthCheck auth succeeds.
|
||||
RegistrationToken string
|
||||
// Prober overrides the default gRPC health probe (used by tests).
|
||||
Prober ModelProber
|
||||
Prober ModelProber
|
||||
DB *gorm.DB
|
||||
Interval time.Duration // default 30s
|
||||
ScaleDownDelay time.Duration // default 5m
|
||||
@@ -191,7 +191,7 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
|
||||
// Pending-op drain for admin install — not a per-replica load.
|
||||
// Replica 0 is the conventional admin slot. Install is idempotent:
|
||||
// the worker short-circuits if the backend is already running.
|
||||
reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0)
|
||||
reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0, "", nil)
|
||||
if err != nil {
|
||||
applyErr = err
|
||||
} else if !reply.Success {
|
||||
|
||||
@@ -17,24 +17,24 @@ import (
|
||||
// Workers are generic — they don't have a fixed backend type.
|
||||
// The SmartRouter dynamically installs backends via NATS backend.install events.
|
||||
type BackendNode struct {
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:255" json:"name"`
|
||||
NodeType string `gorm:"size:32;default:backend" json:"node_type"` // backend, agent
|
||||
Address string `gorm:"size:255" json:"address"` // host:port for gRPC
|
||||
HTTPAddress string `gorm:"size:255" json:"http_address"` // host:port for HTTP file transfer
|
||||
Status string `gorm:"size:32;default:registering" json:"status"` // registering, healthy, unhealthy, draining, pending
|
||||
TokenHash string `gorm:"size:64" json:"-"` // SHA-256 of registration token
|
||||
TotalVRAM uint64 `gorm:"column:total_vram" json:"total_vram"` // Total GPU VRAM in bytes
|
||||
AvailableVRAM uint64 `gorm:"column:available_vram" json:"available_vram"` // Available GPU VRAM in bytes
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:255" json:"name"`
|
||||
NodeType string `gorm:"size:32;default:backend" json:"node_type"` // backend, agent
|
||||
Address string `gorm:"size:255" json:"address"` // host:port for gRPC
|
||||
HTTPAddress string `gorm:"size:255" json:"http_address"` // host:port for HTTP file transfer
|
||||
Status string `gorm:"size:32;default:registering" json:"status"` // registering, healthy, unhealthy, draining, pending
|
||||
TokenHash string `gorm:"size:64" json:"-"` // SHA-256 of registration token
|
||||
TotalVRAM uint64 `gorm:"column:total_vram" json:"total_vram"` // Total GPU VRAM in bytes
|
||||
AvailableVRAM uint64 `gorm:"column:available_vram" json:"available_vram"` // Available GPU VRAM in bytes
|
||||
// ReservedVRAM is a soft, in-tick reservation deducted by the scheduler when
|
||||
// it picks this node to load a model. Workers reset it back to 0 on each
|
||||
// heartbeat (the worker is the source of truth for actual free VRAM); the
|
||||
// reservation is only here to keep two scheduling decisions within the
|
||||
// same heartbeat window from over-committing the same node.
|
||||
ReservedVRAM uint64 `gorm:"column:reserved_vram;default:0" json:"reserved_vram"`
|
||||
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
|
||||
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
|
||||
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
|
||||
ReservedVRAM uint64 `gorm:"column:reserved_vram;default:0" json:"reserved_vram"`
|
||||
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
|
||||
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
|
||||
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
|
||||
// MaxReplicasPerModel caps how many replicas of any one model can run on
|
||||
// this node concurrently. Default 1 preserves the historical "one
|
||||
// (node, model)" assumption; set higher (via worker --max-replicas-per-model)
|
||||
@@ -44,12 +44,12 @@ type BackendNode struct {
|
||||
// admin override. When true, the worker's CLI value is ignored on
|
||||
// re-registration so the override survives worker restarts. Cleared
|
||||
// by an explicit "reset to worker default" action.
|
||||
MaxReplicasPerModelManuallySet bool `gorm:"column:max_replicas_per_model_manually_set;default:false" json:"max_replicas_per_model_manually_set"`
|
||||
APIKeyID string `gorm:"size:36" json:"-"` // auto-provisioned API key ID (for cleanup)
|
||||
AuthUserID string `gorm:"size:36" json:"-"` // auto-provisioned user ID (for cleanup)
|
||||
LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
MaxReplicasPerModelManuallySet bool `gorm:"column:max_replicas_per_model_manually_set;default:false" json:"max_replicas_per_model_manually_set"`
|
||||
APIKeyID string `gorm:"size:36" json:"-"` // auto-provisioned API key ID (for cleanup)
|
||||
AuthUserID string `gorm:"size:36" json:"-"` // auto-provisioned user ID (for cleanup)
|
||||
LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -79,17 +79,17 @@ const (
|
||||
// gRPC Address (each replica is a separate worker process on its own port),
|
||||
// and its own InFlight counter.
|
||||
type NodeModel struct {
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
NodeID string `gorm:"index;size:36" json:"node_id"`
|
||||
ModelName string `gorm:"index;size:255" json:"model_name"`
|
||||
ReplicaIndex int `gorm:"column:replica_index;default:0;index" json:"replica_index"`
|
||||
Address string `gorm:"size:255" json:"address"` // gRPC address for this replica's backend process
|
||||
State string `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
|
||||
InFlight int `json:"in_flight"` // number of active requests on this replica
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
LoadingBy string `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
|
||||
BackendType string `gorm:"size:128" json:"backend_type,omitempty"` // e.g. "llama-cpp"; used by reconciler to replicate loads
|
||||
ModelOptsBlob []byte `gorm:"type:bytea" json:"-"` // serialized pb.ModelOptions for replica scale-ups
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
NodeID string `gorm:"index;size:36" json:"node_id"`
|
||||
ModelName string `gorm:"index;size:255" json:"model_name"`
|
||||
ReplicaIndex int `gorm:"column:replica_index;default:0;index" json:"replica_index"`
|
||||
Address string `gorm:"size:255" json:"address"` // gRPC address for this replica's backend process
|
||||
State string `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
|
||||
InFlight int `json:"in_flight"` // number of active requests on this replica
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
LoadingBy string `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
|
||||
BackendType string `gorm:"size:128" json:"backend_type,omitempty"` // e.g. "llama-cpp"; used by reconciler to replicate loads
|
||||
ModelOptsBlob []byte `gorm:"type:bytea" json:"-"` // serialized pb.ModelOptions for replica scale-ups
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -668,10 +668,21 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string)
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// FindAndLockNodeWithModel atomically finds the least-loaded node with the given
|
||||
// model loaded and increments its in-flight counter within a single transaction.
|
||||
// The SELECT FOR UPDATE row lock prevents concurrent eviction from removing the
|
||||
// NodeModel row between the find and increment operations.
|
||||
// FindAndLockNodeWithModel atomically finds the best loaded replica of the
|
||||
// given model and increments its in-flight counter within a single
|
||||
// transaction. The SELECT FOR UPDATE row lock prevents concurrent eviction
|
||||
// from removing the NodeModel row between the find and increment operations,
|
||||
// and serializes contending routers so concurrent picks distribute across
|
||||
// replicas instead of all landing on the same row.
|
||||
//
|
||||
// **Policy:** the SQL ORDER BY below MUST mirror PickBestReplica
|
||||
// (replicapicker.go). PickBestReplica is the canonical Go implementation of
|
||||
// the same rule — the per-frontend rotating-replica cache (TODO, see
|
||||
// pkg/model/loader.go) will eventually use it against in-memory snapshots so
|
||||
// hot inference requests don't pay this DB round-trip. If you change the
|
||||
// ordering here, change both sides; the TestFindAndLockNodeWithModelMirror
|
||||
// spec ("agrees with PickBestReplica on a seeded dataset") fails fast if they
|
||||
// drift.
|
||||
//
|
||||
// When candidateNodeIDs is non-empty, only nodes in that set are considered.
|
||||
// Pass nil (or empty) to consider any node. This lets callers pre-filter by
|
||||
@@ -683,16 +694,16 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
|
||||
var node BackendNode
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Order by in_flight ASC (least busy replica), then by last_used ASC
|
||||
// (round-robin between equally-loaded replicas — oldest used wins, and
|
||||
// every successful pick refreshes last_used below, so the "oldest" naturally
|
||||
// rotates through the candidate set). available_vram DESC is the final
|
||||
// tiebreaker for cold starts where last_used is identical.
|
||||
// Mirror of PickBestReplica's policy (see replicapicker.go):
|
||||
// 1. in_flight ASC — least busy replica.
|
||||
// 2. last_used ASC — round-robin between equally-loaded replicas.
|
||||
// Every successful pick refreshes last_used below, so the
|
||||
// "oldest" tier naturally rotates through the candidate set.
|
||||
// Without this tier, in_flight ties collapsed to "fattest GPU
|
||||
// wins every time" and one node took nearly all the load.
|
||||
// 3. available_vram DESC — final tiebreaker for cold starts where
|
||||
// last_used is identical across replicas.
|
||||
//
|
||||
// Without the last_used tier, a tie on in_flight (the common case at low
|
||||
// to moderate concurrency where requests don't overlap) collapses to
|
||||
// "biggest GPU wins every time" and one node ends up taking nearly all
|
||||
// the load while replicas on other nodes sit idle.
|
||||
// Filter on backend_nodes.status = healthy in the inner JOIN itself,
|
||||
// not only in the later node-fetch step. The previous version picked
|
||||
// a (node_id, replica) pair purely on node_models state, then bailed
|
||||
@@ -1287,7 +1298,7 @@ func (r *NodeRegistry) UpdateMaxReplicasPerModel(ctx context.Context, nodeID str
|
||||
res := r.db.WithContext(ctx).Model(&BackendNode{}).
|
||||
Where("id = ?", nodeID).
|
||||
Updates(map[string]any{
|
||||
ColMaxReplicasPerModel: n,
|
||||
ColMaxReplicasPerModel: n,
|
||||
"max_replicas_per_model_manually_set": true,
|
||||
})
|
||||
if res.Error != nil {
|
||||
@@ -1460,7 +1471,7 @@ func (r *NodeRegistry) UpsertPendingBackendOp(ctx context.Context, nodeID, backe
|
||||
NextRetryAt: time.Now(),
|
||||
}
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "node_id"}, {Name: "backend"}, {Name: "op"}},
|
||||
Columns: []clause.Column{{Name: "node_id"}, {Name: "backend"}, {Name: "op"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"galleries", "next_retry_at"}),
|
||||
}).Create(&row).Error
|
||||
}
|
||||
@@ -1515,6 +1526,27 @@ func (r *NodeRegistry) RecordPendingBackendOpFailure(ctx context.Context, id uin
|
||||
})
|
||||
}
|
||||
|
||||
// RecordPendingBackendOpInFlight is the "soft failure" cousin of
|
||||
// RecordPendingBackendOpFailure. Used when a NATS install round-trip timed
|
||||
// out but the worker is still installing in the background. Stores the
|
||||
// message in LastError and pushes NextRetryAt out by `retryDelay` (typically
|
||||
// the install timeout) so the reconciler does not immediately re-fire
|
||||
// another install while the worker is still busy.
|
||||
//
|
||||
// Attempts is intentionally NOT incremented: an in-flight timeout is not a
|
||||
// failed attempt, it is a still-in-progress one. Incrementing it would let a
|
||||
// genuinely-progressing slow install (e.g. 30 GB CUDA image on Wi-Fi) trip
|
||||
// the maxPendingBackendOpAttempts cap in the reconciler and dead-letter the
|
||||
// row while the worker is still legitimately working.
|
||||
func (r *NodeRegistry) RecordPendingBackendOpInFlight(ctx context.Context, id uint, lastError string, retryDelay time.Duration) error {
|
||||
return r.db.WithContext(ctx).Model(&PendingBackendOp{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"last_error": lastError,
|
||||
"next_retry_at": time.Now().Add(retryDelay),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// backoffForAttempt is exponential from 30s doubling up to a 15m cap. The
|
||||
// reconciler tick is 30s so anything shorter would just re-fire immediately.
|
||||
func backoffForAttempt(attempts int) time.Duration {
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -357,6 +358,79 @@ var _ = Describe("NodeRegistry", func() {
|
||||
_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("agrees with PickBestReplica on a seeded dataset (policy mirror)", func() {
|
||||
// Guard against drift between the SQL ORDER BY in
|
||||
// FindAndLockNodeWithModel and the canonical Go implementation in
|
||||
// PickBestReplica. The two layers will eventually diverge in
|
||||
// caller (DB-backed atomic pick vs in-memory snapshot pick for the
|
||||
// per-frontend rotating cache), but the policy itself must stay
|
||||
// the single source of truth. If this test fails, update *both*
|
||||
// sides — never just one.
|
||||
//
|
||||
// Scenario exercises all three tiers:
|
||||
// - "loser-busy" has the most VRAM but in_flight=2 — loses tier 1.
|
||||
// - "loser-recent" ties at in_flight=0 but its last_used is the
|
||||
// newest of the in_flight=0 group — loses tier 2.
|
||||
// - "winner-mid" and "winner-fat" both tie at in_flight=0 and
|
||||
// share the oldest last_used — tier 3 decides: fattest wins.
|
||||
loserBusy := makeNode("mirror-loser-busy", "10.0.0.70:50051", 32_000_000_000)
|
||||
loserRecent := makeNode("mirror-loser-recent", "10.0.0.71:50051", 8_000_000_000)
|
||||
winnerMid := makeNode("mirror-winner-mid", "10.0.0.72:50051", 16_000_000_000)
|
||||
winnerFat := makeNode("mirror-winner-fat", "10.0.0.73:50051", 24_000_000_000)
|
||||
for _, n := range []*BackendNode{loserBusy, loserRecent, winnerMid, winnerFat} {
|
||||
Expect(registry.Register(context.Background(), n, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), n.ID, "mirror-model", 0, "loaded", "", 0)).To(Succeed())
|
||||
}
|
||||
|
||||
// Force in_flight=2 on the "busy" node so tier 1 disqualifies it.
|
||||
Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
|
||||
|
||||
// Slam last_used to known values so the test is deterministic
|
||||
// regardless of clock resolution between the helpers above.
|
||||
base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
set := func(id string, t time.Time) {
|
||||
Expect(db.Model(&NodeModel{}).
|
||||
Where("node_id = ? AND model_name = ?", id, "mirror-model").
|
||||
Update("last_used", t).Error).To(Succeed())
|
||||
}
|
||||
set(loserBusy.ID, base) // newest doesn't matter — already disqualified by tier 1
|
||||
set(loserRecent.ID, base.Add(time.Hour))
|
||||
set(winnerMid.ID, base)
|
||||
set(winnerFat.ID, base)
|
||||
|
||||
// Pull the same dataset both pickers will operate on. The Go
|
||||
// picker is a faithful representation of the policy; the SQL is
|
||||
// the production path.
|
||||
var rows []NodeModel
|
||||
Expect(db.Where("model_name = ? AND state = ?", "mirror-model", "loaded").
|
||||
Find(&rows).Error).To(Succeed())
|
||||
candidates := make([]ReplicaCandidate, 0, len(rows))
|
||||
for _, nm := range rows {
|
||||
var bn BackendNode
|
||||
Expect(db.First(&bn, "id = ? AND status = ?", nm.NodeID, StatusHealthy).Error).To(Succeed())
|
||||
candidates = append(candidates, ReplicaCandidate{
|
||||
NodeID: nm.NodeID,
|
||||
Address: bn.Address,
|
||||
ReplicaIndex: nm.ReplicaIndex,
|
||||
InFlight: nm.InFlight,
|
||||
LastUsed: nm.LastUsed,
|
||||
AvailableVRAM: bn.AvailableVRAM,
|
||||
})
|
||||
}
|
||||
goPick := PickBestReplica(candidates)
|
||||
Expect(goPick).ToNot(BeNil())
|
||||
|
||||
sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(sqlNode.ID).To(Equal(goPick.NodeID),
|
||||
"SQL ORDER BY picked %s; PickBestReplica picked %s — policy has drifted",
|
||||
sqlNode.ID, goPick.NodeID)
|
||||
// Sanity check: the policy says winner-fat wins on tier 3.
|
||||
Expect(goPick.NodeID).To(Equal(winnerFat.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("MarkHealthy and MarkUnhealthy round-trip", func() {
|
||||
|
||||
69
core/services/nodes/replicapicker.go
Normal file
69
core/services/nodes/replicapicker.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package nodes
|
||||
|
||||
import "time"
|
||||
|
||||
// ReplicaCandidate is the minimum view of a loaded model replica needed to
|
||||
// apply the routing policy. It is intentionally decoupled from the gorm models
|
||||
// (BackendNode, NodeModel) so the same picker can run against fresh DB rows
|
||||
// (SmartRouter.Route → FindAndLockNodeWithModel) and against an in-memory
|
||||
// snapshot (the per-frontend rotating cache flagged in pkg/model — see TODO
|
||||
// below).
|
||||
type ReplicaCandidate struct {
|
||||
NodeID string
|
||||
Address string
|
||||
ReplicaIndex int
|
||||
InFlight int
|
||||
LastUsed time.Time
|
||||
AvailableVRAM uint64
|
||||
}
|
||||
|
||||
// PickBestReplica is the single source of truth for which loaded replica of a
|
||||
// model serves the next request.
|
||||
//
|
||||
// Policy (ordered tiers, first non-tie wins):
|
||||
// 1. Least in-flight wins — primary load-balancing signal.
|
||||
// 2. Oldest last_used wins — round-robin between equally-loaded replicas.
|
||||
// Every successful pick refreshes last_used (in FindAndLockNodeWithModel's
|
||||
// transaction and in TouchNodeModel on cache hits), so the "oldest" tier
|
||||
// naturally rotates through the candidate set without a separate cursor.
|
||||
// 3. Largest available_vram wins — cold-start tiebreaker for replicas that
|
||||
// have never been picked (identical last_used).
|
||||
//
|
||||
// Two callers must agree on this policy:
|
||||
//
|
||||
// - SmartRouter.Route, via the SQL ORDER BY in FindAndLockNodeWithModel
|
||||
// (registry.go). That query MUST mirror this function — TestPickerSQLMirror
|
||||
// asserts both sides agree on a representative dataset.
|
||||
//
|
||||
// - The per-frontend rotating-replica cache (NOT YET IMPLEMENTED — see
|
||||
// pkg/model/loader.go and pkg/model/initializers.go for the integration
|
||||
// point). When that cache lands, it will call PickBestReplica against an
|
||||
// in-memory snapshot using locally-tracked in-flight counters and skip the
|
||||
// per-request DB round-trip.
|
||||
//
|
||||
// Returns nil when the candidate list is empty. Does not allocate.
|
||||
func PickBestReplica(candidates []ReplicaCandidate) *ReplicaCandidate {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
best := &candidates[0]
|
||||
for i := 1; i < len(candidates); i++ {
|
||||
c := &candidates[i]
|
||||
if betterReplica(c, best) {
|
||||
best = c
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
// betterReplica reports whether candidate a is preferred over candidate b
|
||||
// under the policy documented on PickBestReplica.
|
||||
func betterReplica(a, b *ReplicaCandidate) bool {
|
||||
if a.InFlight != b.InFlight {
|
||||
return a.InFlight < b.InFlight
|
||||
}
|
||||
if !a.LastUsed.Equal(b.LastUsed) {
|
||||
return a.LastUsed.Before(b.LastUsed)
|
||||
}
|
||||
return a.AvailableVRAM > b.AvailableVRAM
|
||||
}
|
||||
81
core/services/nodes/replicapicker_test.go
Normal file
81
core/services/nodes/replicapicker_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("PickBestReplica", func() {
|
||||
// Use a single reference time so every test that wants identical
|
||||
// last_used can share it without relying on time.Now() interleavings.
|
||||
ref := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
It("returns nil for an empty candidate list", func() {
|
||||
Expect(PickBestReplica(nil)).To(BeNil())
|
||||
Expect(PickBestReplica([]ReplicaCandidate{})).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns the only candidate when there is just one", func() {
|
||||
only := ReplicaCandidate{NodeID: "only", InFlight: 99, LastUsed: ref, AvailableVRAM: 1}
|
||||
pick := PickBestReplica([]ReplicaCandidate{only})
|
||||
Expect(pick).ToNot(BeNil())
|
||||
Expect(pick.NodeID).To(Equal("only"))
|
||||
})
|
||||
|
||||
It("prefers the replica with the lowest in_flight", func() {
|
||||
// Without the in-flight tier, the larger-VRAM node would win.
|
||||
cs := []ReplicaCandidate{
|
||||
{NodeID: "busy-big", InFlight: 3, LastUsed: ref, AvailableVRAM: 24_000_000_000},
|
||||
{NodeID: "idle-small", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
|
||||
{NodeID: "mid", InFlight: 1, LastUsed: ref, AvailableVRAM: 16_000_000_000},
|
||||
}
|
||||
Expect(PickBestReplica(cs).NodeID).To(Equal("idle-small"))
|
||||
})
|
||||
|
||||
It("uses oldest last_used as the tiebreaker when in_flight ties", func() {
|
||||
// All three tied on in_flight=0. Without last_used, available_vram
|
||||
// would pin every pick to the fattest node — the exact bug
|
||||
// fix(distributed): round-robin replicas of the same model addressed.
|
||||
cs := []ReplicaCandidate{
|
||||
{NodeID: "fat-recent", InFlight: 0, LastUsed: ref.Add(2 * time.Second), AvailableVRAM: 24_000_000_000},
|
||||
{NodeID: "small-oldest", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
|
||||
{NodeID: "mid-middle", InFlight: 0, LastUsed: ref.Add(1 * time.Second), AvailableVRAM: 16_000_000_000},
|
||||
}
|
||||
Expect(PickBestReplica(cs).NodeID).To(Equal("small-oldest"))
|
||||
})
|
||||
|
||||
It("uses largest available_vram as the final tiebreaker", func() {
|
||||
// in_flight tied AND last_used tied — pick the largest GPU.
|
||||
cs := []ReplicaCandidate{
|
||||
{NodeID: "small", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
|
||||
{NodeID: "fat", InFlight: 0, LastUsed: ref, AvailableVRAM: 24_000_000_000},
|
||||
{NodeID: "mid", InFlight: 0, LastUsed: ref, AvailableVRAM: 16_000_000_000},
|
||||
}
|
||||
Expect(PickBestReplica(cs).NodeID).To(Equal("fat"))
|
||||
})
|
||||
|
||||
It("respects tier precedence: in_flight beats last_used beats available_vram", func() {
|
||||
// "fat-busy-oldest" wins on neither of the first two tiers; the
|
||||
// "small-idle-recent" replica is busy=0 and should beat it despite
|
||||
// being newer and smaller.
|
||||
cs := []ReplicaCandidate{
|
||||
{NodeID: "fat-busy-oldest", InFlight: 5, LastUsed: ref, AvailableVRAM: 80_000_000_000},
|
||||
{NodeID: "small-idle-recent", InFlight: 0, LastUsed: ref.Add(time.Hour), AvailableVRAM: 4_000_000_000},
|
||||
}
|
||||
Expect(PickBestReplica(cs).NodeID).To(Equal("small-idle-recent"))
|
||||
})
|
||||
|
||||
It("is stable: returns the first candidate when every field ties", func() {
|
||||
// betterReplica returns false on a full tie, so the leading element
|
||||
// remains best. Callers shouldn't depend on this for correctness,
|
||||
// but pinning the behavior here catches accidental reorderings.
|
||||
cs := []ReplicaCandidate{
|
||||
{NodeID: "first", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
|
||||
{NodeID: "second", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
|
||||
{NodeID: "third", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
|
||||
}
|
||||
Expect(PickBestReplica(cs).NodeID).To(Equal("first"))
|
||||
})
|
||||
})
|
||||
@@ -61,8 +61,19 @@ type SmartRouter struct {
|
||||
// completions for one not-yet-loaded model produce ONE round-trip, not
|
||||
// six. Avoids amplifying head-of-line blocking on the worker side.
|
||||
installFlight singleflight.Group
|
||||
// probeCache memoizes recent successful gRPC HealthCheck results so
|
||||
// per-request routing doesn't stall behind a busy backend's serialized
|
||||
// HealthCheck/Predict. See probe_cache.go for the rationale.
|
||||
probeCache *probeCache
|
||||
}
|
||||
|
||||
// probeCacheTTL is how long a successful gRPC HealthCheck on a backend is
|
||||
// trusted before the next request re-probes. Matches healthCheckTTL in
|
||||
// pkg/model/model.go so the single-process and distributed paths share a
|
||||
// staleness budget. The background HealthMonitor still reaps dead backends
|
||||
// independently within ~45s (see perModelMissThreshold).
|
||||
const probeCacheTTL = 30 * time.Second
|
||||
|
||||
// NewSmartRouter creates a new SmartRouter backed by the given ModelRouter.
|
||||
// All optional dependencies are passed via SmartRouterOptions to avoid post-creation races.
|
||||
func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter {
|
||||
@@ -79,6 +90,7 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
db: opts.DB,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
conflictResolver: opts.ConflictResolver,
|
||||
probeCache: newProbeCache(probeCacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -688,7 +700,7 @@ func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNod
|
||||
|
||||
key := fmt.Sprintf("%s|%s|%s|%d", node.ID, backendType, modelID, replicaIndex)
|
||||
v, err, _ := r.installFlight.Do(key, func() (any, error) {
|
||||
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex)
|
||||
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex, "", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -961,14 +973,26 @@ func (r *SmartRouter) stageGenericOptions(ctx context.Context, node *BackendNode
|
||||
}
|
||||
|
||||
// probeHealth checks whether a backend process on the given node/addr is alive
|
||||
// via a gRPC health check with a 2-second timeout. The client is closed after the check.
|
||||
// via a gRPC health check with a 2-second timeout. The client is closed after
|
||||
// the check.
|
||||
//
|
||||
// The result is memoized in r.probeCache for probeCacheTTL. With per-request
|
||||
// routing every inference call lands here, and unbounded re-probing can stall
|
||||
// behind a busy backend that serializes HealthCheck against active Predict.
|
||||
// Concurrent probes for the same (node, addr) coalesce via singleflight so a
|
||||
// burst of N requests for a cold cache costs at most one round-trip, not N.
|
||||
// Failed probes invalidate the cache so the staleness recovery path
|
||||
// (DecrementInFlight + RemoveNodeModel) still triggers on the next request.
|
||||
func (r *SmartRouter) probeHealth(ctx context.Context, node *BackendNode, addr string) bool {
|
||||
client := r.buildClientForAddr(node, addr, false)
|
||||
defer closeClient(client)
|
||||
checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
ok, _ := client.HealthCheck(checkCtx)
|
||||
return ok
|
||||
key := node.ID + "|" + addr
|
||||
return r.probeCache.DoOrCached(key, func() bool {
|
||||
client := r.buildClientForAddr(node, addr, false)
|
||||
defer closeClient(client)
|
||||
checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
ok, _ := client.HealthCheck(checkCtx)
|
||||
return ok
|
||||
})
|
||||
}
|
||||
|
||||
// closeClient closes a gRPC backend client if it implements io.Closer.
|
||||
|
||||
@@ -330,7 +330,7 @@ type upgradeCall struct {
|
||||
replica int
|
||||
}
|
||||
|
||||
func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int) (*messaging.BackendInstallReply, error) {
|
||||
func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
|
||||
// installHook intentionally runs OUTSIDE the mutex: the hook may block
|
||||
// on a channel and we don't want to serialize concurrent callers,
|
||||
// which would defeat the singleflight-overlap test.
|
||||
|
||||
@@ -2,9 +2,15 @@ package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -28,7 +34,7 @@ type backendStopRequest struct {
|
||||
// nats.ErrNoResponders for old workers that don't subscribe to the new
|
||||
// backend.upgrade subject.
|
||||
type NodeCommandSender interface {
|
||||
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error)
|
||||
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error)
|
||||
UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error)
|
||||
DeleteBackend(nodeID, backendName string) (*messaging.BackendDeleteReply, error)
|
||||
ListBackends(nodeID string) (*messaging.BackendListReply, error)
|
||||
@@ -43,18 +49,33 @@ type NodeCommandSender interface {
|
||||
// This mirrors the local ModelLoader's startProcess()/deleteProcess() but
|
||||
// over NATS for remote nodes.
|
||||
type RemoteUnloaderAdapter struct {
|
||||
registry ModelLocator
|
||||
nats messaging.MessagingClient
|
||||
registry ModelLocator
|
||||
nats messaging.MessagingClient
|
||||
installTimeout time.Duration
|
||||
upgradeTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewRemoteUnloaderAdapter creates a new adapter.
|
||||
func NewRemoteUnloaderAdapter(registry ModelLocator, nats messaging.MessagingClient) *RemoteUnloaderAdapter {
|
||||
// NewRemoteUnloaderAdapter creates a new adapter. installTimeout and
|
||||
// upgradeTimeout govern the NATS request-reply deadlines for backend.install
|
||||
// and backend.upgrade respectively. Use
|
||||
// DistributedConfig.BackendInstallTimeoutOrDefault() /
|
||||
// BackendUpgradeTimeoutOrDefault() at construction.
|
||||
func NewRemoteUnloaderAdapter(registry ModelLocator, nats messaging.MessagingClient, installTimeout, upgradeTimeout time.Duration) *RemoteUnloaderAdapter {
|
||||
return &RemoteUnloaderAdapter{
|
||||
registry: registry,
|
||||
nats: nats,
|
||||
registry: registry,
|
||||
nats: nats,
|
||||
installTimeout: installTimeout,
|
||||
upgradeTimeout: upgradeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// InstallTimeout returns the configured backend.install round-trip timeout.
|
||||
// Used by DistributedBackendManager to push NextRetryAt out by this duration
|
||||
// when a worker times out replying but is still installing in the background.
|
||||
func (a *RemoteUnloaderAdapter) InstallTimeout() time.Duration {
|
||||
return a.installTimeout
|
||||
}
|
||||
|
||||
// UnloadRemoteModel finds the node(s) hosting the given model and tells them
|
||||
// to stop their backend process via NATS backend.stop event.
|
||||
// The worker process handles: Free() → kill process.
|
||||
@@ -87,18 +108,59 @@ func (a *RemoteUnloaderAdapter) UnloadRemoteModel(modelName string) error {
|
||||
// is on disk, the worker just spawns a process; only a missing binary
|
||||
// triggers a full gallery pull.
|
||||
//
|
||||
// Timeout: 3 minutes. Most calls return in under 2 seconds (process already
|
||||
// running). The 3-minute ceiling covers the cold-binary spawn-after-download
|
||||
// case while still failing fast enough to surface real worker hangs.
|
||||
// Timeout: configured via DistributedConfig.BackendInstallTimeoutOrDefault
|
||||
// (default 15m). Most calls return in under 2 seconds (process already
|
||||
// running). The 15-minute ceiling covers the cold-binary spawn-after-download
|
||||
// case on slow links (Jetson Wi-Fi, multi-GB CUDA images) while still
|
||||
// failing fast enough to surface real worker hangs.
|
||||
//
|
||||
// For force-reinstall (admin-driven Upgrade), use UpgradeBackend instead —
|
||||
// For force-reinstall (admin-driven Upgrade), use UpgradeBackend instead -
|
||||
// it lives on a different NATS subject so it cannot head-of-line-block
|
||||
// routine load traffic on the same worker.
|
||||
func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) {
|
||||
func (a *RemoteUnloaderAdapter) InstallBackend(
|
||||
nodeID, backendType, modelID, galleriesJSON, uri, name, alias string,
|
||||
replicaIndex int,
|
||||
opID string,
|
||||
onProgress func(messaging.BackendInstallProgressEvent),
|
||||
) (*messaging.BackendInstallReply, error) {
|
||||
subject := messaging.SubjectNodeBackendInstall(nodeID)
|
||||
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex)
|
||||
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex, "opID", opID)
|
||||
|
||||
return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
// Subscribe to the per-op progress subject BEFORE publishing the install
|
||||
// request so we don't miss early events. When onProgress is nil OR opID
|
||||
// is empty (the reconciler-driven retry path), skip subscription entirely:
|
||||
// silent installs cost nothing extra.
|
||||
var sub messaging.Subscription
|
||||
if onProgress != nil && opID != "" {
|
||||
progressSubject := messaging.SubjectNodeBackendInstallProgress(nodeID, opID)
|
||||
s, subErr := a.nats.Subscribe(progressSubject, func(raw []byte) {
|
||||
var ev messaging.BackendInstallProgressEvent
|
||||
if err := json.Unmarshal(raw, &ev); err != nil {
|
||||
xlog.Debug("malformed install progress event", "subject", progressSubject, "error", err)
|
||||
return
|
||||
}
|
||||
// Goroutine guard: a slow onProgress callback must not stall
|
||||
// the NATS reader thread.
|
||||
//
|
||||
// NOTE: events spawn one goroutine each, so ordering at the
|
||||
// consumer is best-effort. In practice the worker debounces to
|
||||
// ~250ms which is far larger than goroutine scheduling jitter,
|
||||
// so reordering is rare. The worker's final Flush() event is
|
||||
// intended to win as the terminal tick. A future hardening pass
|
||||
// could add a Seq uint64 field to BackendInstallProgressEvent
|
||||
// and drop stale-by-seq at the bridge if reordering becomes a
|
||||
// real UX issue.
|
||||
go onProgress(ev)
|
||||
})
|
||||
if subErr != nil {
|
||||
xlog.Warn("Failed to subscribe to install progress subject; proceeding without progress streaming",
|
||||
"subject", progressSubject, "error", subErr)
|
||||
} else {
|
||||
sub = s
|
||||
}
|
||||
}
|
||||
|
||||
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
Backend: backendType,
|
||||
ModelID: modelID,
|
||||
BackendGalleries: galleriesJSON,
|
||||
@@ -106,29 +168,46 @@ func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, gal
|
||||
Name: name,
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
}, 3*time.Minute)
|
||||
OpID: opID,
|
||||
}, a.installTimeout)
|
||||
|
||||
if sub != nil {
|
||||
_ = sub.Unsubscribe()
|
||||
}
|
||||
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// UpgradeBackend sends a backend.upgrade request-reply to a worker node.
|
||||
// The worker stops every live process for this backend, force-reinstalls
|
||||
// from the gallery (overwriting the on-disk artifact), and replies. The
|
||||
// next routine InstallBackend call spawns a fresh process with the new
|
||||
// binary — upgrade itself does not start a process.
|
||||
// binary - upgrade itself does not start a process.
|
||||
//
|
||||
// Timeout: 15 minutes. Real-world worst case observed: 8–10 minutes for
|
||||
// large CUDA-l4t backend images on Jetson over WiFi.
|
||||
// Timeout: configured via DistributedConfig.BackendUpgradeTimeoutOrDefault
|
||||
// (default 15m). Real-world worst case observed: 8-10 minutes for large
|
||||
// CUDA-l4t backend images on Jetson over WiFi.
|
||||
func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error) {
|
||||
subject := messaging.SubjectNodeBackendUpgrade(nodeID)
|
||||
xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex)
|
||||
|
||||
return messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{
|
||||
reply, err := messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{
|
||||
Backend: backendType,
|
||||
BackendGalleries: galleriesJSON,
|
||||
URI: uri,
|
||||
Name: name,
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
}, 15*time.Minute)
|
||||
}, a.upgradeTimeout)
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// installWithForceFallback is the rolling-update fallback used by
|
||||
@@ -141,7 +220,7 @@ func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, ga
|
||||
subject := messaging.SubjectNodeBackendInstall(nodeID)
|
||||
xlog.Warn("Falling back to legacy backend.install Force=true (old worker)", "nodeID", nodeID, "backend", backendType)
|
||||
|
||||
return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
Backend: backendType,
|
||||
BackendGalleries: galleriesJSON,
|
||||
URI: uri,
|
||||
@@ -149,7 +228,12 @@ func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, ga
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
Force: true,
|
||||
}, 15*time.Minute)
|
||||
}, a.upgradeTimeout)
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// ListBackends queries a worker node for its installed backends via NATS request-reply.
|
||||
@@ -228,3 +312,14 @@ func (a *RemoteUnloaderAdapter) StopNode(nodeID string) error {
|
||||
subject := messaging.SubjectNodeStop(nodeID)
|
||||
return a.nats.Publish(subject, nil)
|
||||
}
|
||||
|
||||
// isNATSTimeout returns true if err looks like a NATS request-reply timeout.
|
||||
// nats.ErrTimeout is the canonical sentinel; context.DeadlineExceeded can
|
||||
// also surface depending on the client's path; we accept both, plus a
|
||||
// string-match fallback for clients that return a bare error.
|
||||
func isNATSTimeout(err error) bool {
|
||||
if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
return err != nil && strings.Contains(err.Error(), "nats: timeout")
|
||||
}
|
||||
|
||||
@@ -3,13 +3,16 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
@@ -60,6 +63,7 @@ type publishCall struct {
|
||||
type requestCall struct {
|
||||
Subject string
|
||||
Data []byte
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (f *fakeMessagingClient) Publish(subject string, data any) error {
|
||||
@@ -93,10 +97,10 @@ func (f *fakeMessagingClient) SubscribeReply(_ string, _ func(data []byte, reply
|
||||
return &fakeSubscription{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeMessagingClient) Request(subject string, data []byte, _ time.Duration) ([]byte, error) {
|
||||
func (f *fakeMessagingClient) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.requestCalls = append(f.requestCalls, requestCall{Subject: subject, Data: data})
|
||||
f.requestCalls = append(f.requestCalls, requestCall{Subject: subject, Data: data, Timeout: timeout})
|
||||
return f.requestReply, f.requestErr
|
||||
}
|
||||
|
||||
@@ -119,7 +123,7 @@ var _ = Describe("RemoteUnloaderAdapter", func() {
|
||||
BeforeEach(func() {
|
||||
locator = &fakeModelLocator{}
|
||||
mc = &fakeMessagingClient{}
|
||||
adapter = NewRemoteUnloaderAdapter(locator, mc)
|
||||
adapter = NewRemoteUnloaderAdapter(locator, mc, 3*time.Minute, 15*time.Minute)
|
||||
})
|
||||
|
||||
Describe("UnloadRemoteModel", func() {
|
||||
@@ -154,7 +158,7 @@ var _ = Describe("RemoteUnloaderAdapter", func() {
|
||||
}
|
||||
// Use a messaging client that fails the first Publish call only.
|
||||
failOnce := &failOnceMessagingClient{inner: mc, failOn: 0}
|
||||
adapter = NewRemoteUnloaderAdapter(locator, failOnce)
|
||||
adapter = NewRemoteUnloaderAdapter(locator, failOnce, 3*time.Minute, 15*time.Minute)
|
||||
|
||||
Expect(adapter.UnloadRemoteModel("llama")).To(Succeed())
|
||||
|
||||
@@ -259,3 +263,96 @@ func (f *failOnceMessagingClient) Request(subject string, data []byte, timeout t
|
||||
|
||||
func (f *failOnceMessagingClient) IsConnected() bool { return true }
|
||||
func (f *failOnceMessagingClient) Close() {}
|
||||
|
||||
var _ = Describe("RemoteUnloaderAdapter timeout configuration", func() {
|
||||
It("passes the configured install timeout to the messaging client", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall("n1"), messaging.BackendInstallReply{Success: true, Address: "127.0.0.1:0"})
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 7*time.Minute, 11*time.Minute)
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "llama-cpp", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.calls).To(HaveLen(1))
|
||||
Expect(mc.calls[0].Timeout).To(Equal(7 * time.Minute))
|
||||
})
|
||||
|
||||
It("passes the configured upgrade timeout to the messaging client", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade("n1"), messaging.BackendUpgradeReply{Success: true})
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 7*time.Minute, 11*time.Minute)
|
||||
|
||||
_, err := adapter.UpgradeBackend("n1", "llama-cpp", "[]", "", "", "", 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.calls).To(HaveLen(1))
|
||||
Expect(mc.calls[0].Timeout).To(Equal(11 * time.Minute))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RemoteUnloaderAdapter NATS timeout handling", func() {
|
||||
It("wraps nats.ErrTimeout from InstallBackend in galleryop.ErrWorkerStillInstalling", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall("n1"), nats.ErrTimeout)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 100*time.Millisecond, 1*time.Second)
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeTrue(),
|
||||
"expected wrapped ErrWorkerStillInstalling, got %v", err)
|
||||
})
|
||||
|
||||
It("does NOT wrap non-timeout errors", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptErr(messaging.SubjectNodeBackendInstall("n1"), nats.ErrNoResponders)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 100*time.Millisecond, 1*time.Second)
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, galleryop.ErrWorkerStillInstalling)).To(BeFalse())
|
||||
Expect(errors.Is(err, nats.ErrNoResponders)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RemoteUnloaderAdapter install progress streaming", func() {
|
||||
It("forwards BackendInstallProgressEvent values into the onProgress callback when the worker publishes them", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall("n1"), messaging.BackendInstallReply{Success: true, Address: "127.0.0.1:0"})
|
||||
mc.scheduleProgressPublish("n1", "op-abc", []messaging.BackendInstallProgressEvent{
|
||||
{OpID: "op-abc", NodeID: "n1", Backend: "vllm", FileName: "vllm.tar.zst", Current: "100 MB", Total: "1 GB", Percentage: 10},
|
||||
{OpID: "op-abc", NodeID: "n1", Backend: "vllm", FileName: "vllm.tar.zst", Current: "500 MB", Total: "1 GB", Percentage: 50},
|
||||
})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 1*time.Second, 1*time.Second)
|
||||
var (
|
||||
received []messaging.BackendInstallProgressEvent
|
||||
mu sync.Mutex
|
||||
)
|
||||
onProgress := func(ev messaging.BackendInstallProgressEvent) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
received = append(received, ev)
|
||||
}
|
||||
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "op-abc", onProgress)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Eventually(func() int {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(received)
|
||||
}, "1s").Should(Equal(2))
|
||||
})
|
||||
|
||||
It("does NOT subscribe when onProgress is nil (reconciler retry path)", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
mc.scriptReply(messaging.SubjectNodeBackendInstall("n1"), messaging.BackendInstallReply{Success: true})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 1*time.Second, 1*time.Second)
|
||||
_, err := adapter.InstallBackend("n1", "vllm", "", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.subscribeCalls()).To(BeEmpty(),
|
||||
"reconciler-driven retries must not subscribe to the per-op progress subject")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
@@ -15,7 +17,7 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(nodeID),
|
||||
messaging.BackendUpgradeReply{Success: true})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(reply.Success).To(BeTrue())
|
||||
@@ -24,7 +26,7 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
|
||||
It("returns the underlying error when the subject has no responders", func() {
|
||||
mc := newScriptedMessagingClient() // unscripted subject => fakeNoRespondersErr by harness convention
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc)
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
_, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -7,14 +7,22 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// installProgressDebounce is the leading-edge window the worker uses when
|
||||
// streaming download progress to the master. 250ms caps wire chatter at
|
||||
// ~4 events/sec per in-flight install while still surfacing every
|
||||
// meaningful percentage jump.
|
||||
const installProgressDebounce = 250 * time.Millisecond
|
||||
|
||||
// buildProcessKey is the supervisor's stable identifier for a backend gRPC
|
||||
// process. It includes the replica index so the same model can run multiple
|
||||
// processes on a worker simultaneously without colliding on the same map slot
|
||||
@@ -100,6 +108,20 @@ func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// When the master tagged this install with an OpID, stream the
|
||||
// gallery download progress back to it on the per-op NATS subject.
|
||||
// Old masters that omit OpID stay on the silent path so they keep
|
||||
// working without changes. The publisher releases its mutex before
|
||||
// every Publish so a slow link never stalls the download loop, and
|
||||
// the deferred Flush guarantees a terminal-percentage event reaches
|
||||
// the master even when the install errors out.
|
||||
var downloadCb func(file, current, total string, percentage float64)
|
||||
if req.OpID != "" && s.nats != nil {
|
||||
publisher := nodes.NewDebouncedInstallProgressPublisher(s.nats, s.nodeID, req.OpID, req.Backend, installProgressDebounce)
|
||||
downloadCb = publisher.OnDownload
|
||||
defer publisher.Flush()
|
||||
}
|
||||
|
||||
// On upgrade, run the gallery install path even if the binary already
|
||||
// exists on disk: findBackend would otherwise short-circuit and we'd
|
||||
// restart the same stale binary. The force flag passed to
|
||||
@@ -112,14 +134,14 @@ func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest,
|
||||
if req.URI != "" {
|
||||
xlog.Info("Installing backend from external URI", "backend", req.Backend, "uri", req.URI, "force", force)
|
||||
if err := galleryop.InstallExternalBackend(
|
||||
context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("installing backend from gallery: %w", err)
|
||||
}
|
||||
} else {
|
||||
xlog.Info("Installing backend from gallery", "backend", req.Backend, "force", force)
|
||||
if err := gallery.InstallBackendFromGallery(
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, force, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, downloadCb, force, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("installing backend from gallery: %w", err)
|
||||
}
|
||||
|
||||
@@ -16,8 +16,12 @@ const MaxSnippetSeconds = 30
|
||||
|
||||
// AudioSnippet captures the first MaxSnippetSeconds of a WAV file and computes
|
||||
// quality metrics. The result is a map suitable for merging into a BackendTrace
|
||||
// Data field.
|
||||
func AudioSnippet(wavPath string) map[string]any {
|
||||
// Data field. maxBytes caps the embedded base64 waveform so a single TTS or
|
||||
// transcription trace cannot blow past the backend-trace body cap (~1.3 MiB
|
||||
// of base64 per 30s of 16 kHz mono int16 PCM otherwise); when the encoded
|
||||
// waveform would exceed the cap the audio_wav_base64 field is dropped and
|
||||
// the rest of the metrics are returned. maxBytes <= 0 disables the cap.
|
||||
func AudioSnippet(wavPath string, maxBytes int) map[string]any {
|
||||
raw, err := os.ReadFile(wavPath)
|
||||
if err != nil {
|
||||
xlog.Warn("audio snippet: read failed", "path", wavPath, "error", err)
|
||||
@@ -34,12 +38,14 @@ func AudioSnippet(wavPath string) map[string]any {
|
||||
sampleRate = 16000
|
||||
}
|
||||
|
||||
return AudioSnippetFromPCM(pcm, sampleRate, len(pcm))
|
||||
return AudioSnippetFromPCM(pcm, sampleRate, len(pcm), maxBytes)
|
||||
}
|
||||
|
||||
// AudioSnippetFromPCM builds an audio snippet from raw PCM bytes (int16 LE mono).
|
||||
// totalPCMBytes is the full audio size before truncation (used to compute total duration).
|
||||
func AudioSnippetFromPCM(pcm []byte, sampleRate int, totalPCMBytes int) map[string]any {
|
||||
// totalPCMBytes is the full audio size before truncation (used to compute
|
||||
// total duration). maxBytes caps the embedded base64 waveform as described
|
||||
// on AudioSnippet.
|
||||
func AudioSnippetFromPCM(pcm []byte, sampleRate, totalPCMBytes, maxBytes int) map[string]any {
|
||||
if len(pcm) == 0 || len(pcm)%2 != 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -89,8 +95,7 @@ func AudioSnippetFromPCM(pcm []byte, sampleRate int, totalPCMBytes int) map[stri
|
||||
}
|
||||
buf.Write(snippetPCM)
|
||||
|
||||
return map[string]any{
|
||||
"audio_wav_base64": base64.StdEncoding.EncodeToString(buf.Bytes()),
|
||||
out := map[string]any{
|
||||
"audio_duration_s": math.Round(durationS*100) / 100,
|
||||
"audio_snippet_s": math.Round(snippetDuration*100) / 100,
|
||||
"audio_sample_rate": sampleRate,
|
||||
@@ -99,4 +104,15 @@ func AudioSnippetFromPCM(pcm []byte, sampleRate int, totalPCMBytes int) map[stri
|
||||
"audio_peak_dbfs": math.Round(peakDBFS*10) / 10,
|
||||
"audio_dc_offset": math.Round(dcOffset*10000) / 10000,
|
||||
}
|
||||
// Skip the embedded waveform when it would dominate the trace payload.
|
||||
// Truncating mid-base64 produces an undecodable string, so the right
|
||||
// move is to drop the field and let the UI render just the metrics.
|
||||
encodedSize := base64.StdEncoding.EncodedLen(buf.Len())
|
||||
if maxBytes <= 0 || encodedSize <= maxBytes {
|
||||
out["audio_wav_base64"] = base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
} else {
|
||||
xlog.Debug("audio snippet: dropping audio_wav_base64", "encoded_bytes", encodedSize, "max_bytes", maxBytes)
|
||||
out["audio_wav_base64_dropped_bytes"] = encodedSize
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
49
core/trace/audio_snippet_test.go
Normal file
49
core/trace/audio_snippet_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package trace_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
)
|
||||
|
||||
// One second of mono 16-bit PCM at 16 kHz: 32 KiB raw. After the 44-byte
|
||||
// WAV header and base64 encoding the snippet runs ~42 KiB, which is well
|
||||
// over the small caps used here and matches the smallest realistic TTS
|
||||
// output size.
|
||||
const (
|
||||
snippetSampleRate = 16000
|
||||
snippetSeconds = 1
|
||||
)
|
||||
|
||||
func makePCM(seconds, sampleRate int) []byte {
|
||||
return make([]byte, seconds*sampleRate*2) // int16 mono
|
||||
}
|
||||
|
||||
var _ = Describe("AudioSnippetFromPCM byte cap", func() {
|
||||
pcm := makePCM(snippetSeconds, snippetSampleRate)
|
||||
totalPCM := len(pcm)
|
||||
|
||||
It("omits audio_wav_base64 when the encoded snippet would exceed the cap, keeping the metrics", func() {
|
||||
out := trace.AudioSnippetFromPCM(pcm, snippetSampleRate, totalPCM, 1024)
|
||||
|
||||
Expect(out).ToNot(BeNil(), "metrics must still be returned even when the waveform is dropped")
|
||||
Expect(out).ToNot(HaveKey("audio_wav_base64"), "oversized base64 must be dropped so the UI does not try to render invalid audio data")
|
||||
Expect(out).To(HaveKey("audio_duration_s"))
|
||||
Expect(out).To(HaveKey("audio_sample_rate"))
|
||||
Expect(out).To(HaveKey("audio_rms_dbfs"))
|
||||
})
|
||||
|
||||
It("includes audio_wav_base64 when the snippet fits under the cap", func() {
|
||||
out := trace.AudioSnippetFromPCM(pcm, snippetSampleRate, totalPCM, 1024*1024)
|
||||
|
||||
Expect(out).To(HaveKey("audio_wav_base64"))
|
||||
Expect(out["audio_wav_base64"]).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("includes audio_wav_base64 when the cap is disabled (0)", func() {
|
||||
out := trace.AudioSnippetFromPCM(pcm, snippetSampleRate, totalPCM, 0)
|
||||
|
||||
Expect(out).To(HaveKey("audio_wav_base64"))
|
||||
})
|
||||
})
|
||||
@@ -2,6 +2,7 @@ package trace
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -49,13 +50,25 @@ var backendMu sync.Mutex
|
||||
var backendLogChan = make(chan *BackendTrace, 100)
|
||||
var backendInitOnce sync.Once
|
||||
|
||||
func InitBackendTracingIfEnabled(maxItems int) {
|
||||
// backendMaxBodyBytes caps each captured string value in a BackendTrace.Data
|
||||
// field to keep the /api/backend-traces JSON small enough for the admin UI to
|
||||
// load on every 5s auto-refresh. Mirrors the API-trace body cap added in
|
||||
// commit 61bf34ea: without it a chatty LLM workload (full message history per
|
||||
// trace) or any TTS run (~1.3 MiB of audio_wav_base64 per trace) blows the
|
||||
// payload past tens of MiB and locks the Traces page in a loading state.
|
||||
//
|
||||
// 0 disables the cap. Set on the first InitBackendTracingIfEnabled call only,
|
||||
// matching the sync.Once-guarded maxItems semantics.
|
||||
var backendMaxBodyBytes int
|
||||
|
||||
func InitBackendTracingIfEnabled(maxItems, maxBodyBytes int) {
|
||||
backendInitOnce.Do(func() {
|
||||
if maxItems <= 0 {
|
||||
maxItems = 100
|
||||
}
|
||||
backendMu.Lock()
|
||||
backendTraceBuffer = circularbuffer.New[*BackendTrace](maxItems)
|
||||
backendMaxBodyBytes = maxBodyBytes
|
||||
backendMu.Unlock()
|
||||
|
||||
go func() {
|
||||
@@ -71,6 +84,9 @@ func InitBackendTracingIfEnabled(maxItems int) {
|
||||
}
|
||||
|
||||
func RecordBackendTrace(t BackendTrace) {
|
||||
if t.Data != nil && backendMaxBodyBytes > 0 {
|
||||
t.Data = capDataStrings(t.Data, backendMaxBodyBytes)
|
||||
}
|
||||
select {
|
||||
case backendLogChan <- &t:
|
||||
default:
|
||||
@@ -78,6 +94,35 @@ func RecordBackendTrace(t BackendTrace) {
|
||||
}
|
||||
}
|
||||
|
||||
// capDataStrings walks a trace Data map and replaces any string value (at any
|
||||
// depth) that exceeds maxBytes with a fixed-size marker that names the
|
||||
// original byte count. The replacement is intentionally short and not valid
|
||||
// base64/JSON: the goal is to flag "this was dropped" cheaply, not to keep a
|
||||
// partial value that the UI might try to render. Non-string scalars and
|
||||
// non-map containers pass through untouched so structural fields like
|
||||
// total_deltas or audio_sample_rate remain useful.
|
||||
func capDataStrings(data map[string]any, maxBytes int) map[string]any {
|
||||
out := make(map[string]any, len(data))
|
||||
for k, v := range data {
|
||||
out[k] = capValue(v, maxBytes)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func capValue(v any, maxBytes int) any {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
if len(val) > maxBytes {
|
||||
return fmt.Sprintf("<truncated: %d bytes>", len(val))
|
||||
}
|
||||
return val
|
||||
case map[string]any:
|
||||
return capDataStrings(val, maxBytes)
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func GetBackendTraces() []BackendTrace {
|
||||
backendMu.Lock()
|
||||
if backendTraceBuffer == nil {
|
||||
@@ -136,3 +181,24 @@ func TruncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// TruncateToBytes caps a string at exactly maxBytes, preserving the leading
|
||||
// content and appending a marker so the UI knows the value was clipped.
|
||||
// Unlike TruncateString it guarantees output <= maxBytes, which matters for
|
||||
// fields that feed back into the trace pipeline: capDataStrings in
|
||||
// RecordBackendTrace re-checks size and would otherwise replace a producer's
|
||||
// head-preserving truncation with the bare marker, losing the prefix.
|
||||
//
|
||||
// maxBytes <= 0 disables the cap, matching backendMaxBodyBytes semantics.
|
||||
func TruncateToBytes(s string, maxBytes int) string {
|
||||
if maxBytes <= 0 || len(s) <= maxBytes {
|
||||
return s
|
||||
}
|
||||
suffix := fmt.Sprintf("...[truncated, %d bytes]", len(s))
|
||||
if len(suffix) >= maxBytes {
|
||||
// Pathologically small caps can't fit the marker; fall back to a
|
||||
// hard cut so the contract (output <= maxBytes) still holds.
|
||||
return s[:maxBytes]
|
||||
}
|
||||
return s[:maxBytes-len(suffix)] + suffix
|
||||
}
|
||||
|
||||
160
core/trace/backend_trace_cap_test.go
Normal file
160
core/trace/backend_trace_cap_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package trace_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
)
|
||||
|
||||
// The /api/backend-traces endpoint ships up to TracingMaxItems entries to the
|
||||
// admin Traces UI on every 5s auto-refresh. Without a cap on the per-trace
|
||||
// Data field, a chatty agent-pool workload (LLM traces carry the full
|
||||
// `messages` array, TTS traces carry ~1.3 MiB of audio_wav_base64) makes the
|
||||
// response tens of MiB. The UI then stays in "loading" forever because the
|
||||
// download + parse runs longer than the refresh interval: the same symptom
|
||||
// the API-trace fix (commit 61bf34ea) addressed on the other side.
|
||||
//
|
||||
// These specs pin the generic safety net (Option A) so any future producer
|
||||
// that stuffs a large string into Data is automatically bounded.
|
||||
|
||||
const (
|
||||
smallCap = 1024
|
||||
smallCapStep = 16
|
||||
)
|
||||
|
||||
var _ = Describe("RecordBackendTrace Data capping", func() {
|
||||
BeforeEach(func() {
|
||||
// Init is sync.Once so the first test wins; subsequent tests just
|
||||
// clear the buffer. The cap value below has to match the first call.
|
||||
trace.InitBackendTracingIfEnabled(64, smallCap)
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
It("replaces oversized top-level string values with a truncation marker", func() {
|
||||
oversized := strings.Repeat("x", smallCap*4)
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceLLM,
|
||||
ModelName: "m",
|
||||
Data: map[string]any{
|
||||
"messages": oversized,
|
||||
"small": "fits",
|
||||
},
|
||||
})
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
|
||||
Expect(got.Data["small"]).To(Equal("fits"), "fields under the cap must pass through untouched")
|
||||
|
||||
// The marker is the contract the UI reads to show truncation; the
|
||||
// concrete shape can evolve but it must be a short fixed-size string
|
||||
// that encodes the original byte count so users know what was dropped.
|
||||
msg, ok := got.Data["messages"].(string)
|
||||
Expect(ok).To(BeTrue(), "string fields stay strings after capping")
|
||||
Expect(len(msg)).To(BeNumerically("<", smallCap), "capped value must fit under the configured cap")
|
||||
Expect(msg).To(ContainSubstring("truncated"))
|
||||
Expect(msg).To(ContainSubstring("4096"), "marker should reference the original byte count for diagnostics")
|
||||
})
|
||||
|
||||
It("recurses into nested maps so deeply nested oversized strings are also bounded", func() {
|
||||
oversized := strings.Repeat("y", smallCap*2)
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceLLM,
|
||||
ModelName: "m",
|
||||
Data: map[string]any{
|
||||
"chat_deltas": map[string]any{
|
||||
"content": oversized,
|
||||
"total_deltas": 5,
|
||||
"tool_call_count": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
|
||||
deltas, ok := got.Data["chat_deltas"].(map[string]any)
|
||||
Expect(ok).To(BeTrue(), "nested map structure must be preserved")
|
||||
Expect(deltas["total_deltas"]).To(Equal(5), "non-string siblings must pass through untouched")
|
||||
|
||||
content, ok := deltas["content"].(string)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(len(content)).To(BeNumerically("<", smallCap), "nested oversized string must still be capped")
|
||||
Expect(content).To(ContainSubstring("truncated"))
|
||||
})
|
||||
|
||||
It("leaves values within the cap untouched", func() {
|
||||
smallVal := strings.Repeat("z", smallCap-smallCapStep)
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceEmbedding,
|
||||
ModelName: "m",
|
||||
Data: map[string]any{
|
||||
"input_text": smallVal,
|
||||
},
|
||||
})
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
|
||||
Expect(got.Data["input_text"]).To(Equal(smallVal))
|
||||
})
|
||||
|
||||
It("does not re-truncate values that producers already capped with TruncateToBytes", func() {
|
||||
// Producers (LLM messages/response, etc.) prefer head-preserving
|
||||
// truncation so users can still read the start of the conversation.
|
||||
// TruncateToBytes guarantees output <= cap, so the generic safety
|
||||
// net below must leave it alone, otherwise the kept prefix gets
|
||||
// thrown away and replaced with the marker.
|
||||
preTruncated := trace.TruncateToBytes(strings.Repeat("a", smallCap*4), smallCap)
|
||||
Expect(len(preTruncated)).To(BeNumerically("<=", smallCap))
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceLLM,
|
||||
ModelName: "m",
|
||||
Data: map[string]any{
|
||||
"messages": preTruncated,
|
||||
},
|
||||
})
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Data["messages"]).To(Equal(preTruncated))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TruncateToBytes", func() {
|
||||
It("returns the input unchanged when it fits", func() {
|
||||
Expect(trace.TruncateToBytes("hello", 1024)).To(Equal("hello"))
|
||||
})
|
||||
|
||||
It("treats maxBytes <= 0 as unlimited", func() {
|
||||
Expect(trace.TruncateToBytes("hello", 0)).To(Equal("hello"))
|
||||
Expect(trace.TruncateToBytes("hello", -1)).To(Equal("hello"))
|
||||
})
|
||||
|
||||
It("caps oversized input to at most maxBytes and preserves the head", func() {
|
||||
in := strings.Repeat("a", 5000)
|
||||
out := trace.TruncateToBytes(in, 100)
|
||||
Expect(len(out)).To(BeNumerically("<=", 100), "output must never exceed the cap so the generic Record-time safety net doesn't fire")
|
||||
Expect(out).To(HavePrefix("a"), "should keep the leading content readable")
|
||||
Expect(out).To(ContainSubstring("truncated"), "should mark the value as truncated for the UI")
|
||||
})
|
||||
|
||||
It("falls back to plain truncation when the cap is smaller than the suffix", func() {
|
||||
in := strings.Repeat("a", 100)
|
||||
out := trace.TruncateToBytes(in, 4)
|
||||
Expect(len(out)).To(Equal(4))
|
||||
Expect(out).To(Equal("aaaa"))
|
||||
})
|
||||
})
|
||||
13
core/trace/trace_suite_test.go
Normal file
13
core/trace/trace_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package trace_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestTrace(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Trace test suite")
|
||||
}
|
||||
@@ -86,6 +86,9 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
|
||||
| `--auto-approve-nodes` | `LOCALAI_AUTO_APPROVE_NODES` | `false` | Auto-approve new worker nodes (skip admin approval) |
|
||||
| `--auth` | `LOCALAI_AUTH` | `false` | **Must be `true`** for distributed mode |
|
||||
| `--auth-database-url` | `LOCALAI_AUTH_DATABASE_URL` | *(required)* | PostgreSQL connection URL |
|
||||
| `--backend-install-timeout` | `LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT` | `15m` | How long the frontend waits for a worker to acknowledge a backend install before considering the request stalled. Raise it when workers pull large backend images over slow links. If a worker takes longer than this, the operation shows as "still installing in background" in the admin UI and clears once the worker finishes. |
|
||||
| `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | Same as the install timeout, applied to backend upgrades (force-reinstall). |
|
||||
| `--expose-node-header` | `LOCALAI_EXPOSE_NODE_HEADER` | `false` | When enabled, inference responses on the OpenAI-compatible endpoints (chat completions, completions, embeddings) as well as the Anthropic Messages (`/v1/messages`) and Ollama (`/api/chat`, `/api/generate`, `/api/embed`) shims carry an `X-LocalAI-Node` header with the ID of the worker node that served the request. Useful for debugging, observability and load-balancer attribution. Off by default: the node ID reveals internal cluster topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency for the same model across multiple replicas, the header may reflect a recent routing decision rather than this exact request's. Acceptable for observability and debugging. |
|
||||
|
||||
### Optional: S3 Object Storage
|
||||
|
||||
@@ -103,6 +106,31 @@ When S3 is not configured, model files are transferred directly from the fronten
|
||||
|
||||
For high-throughput or very large model files, S3 can be more efficient since it avoids streaming through the frontend.
|
||||
|
||||
### Watching Backend Installs
|
||||
|
||||
While a worker downloads a backend, the admin **Operations Bar** at the top
|
||||
of the UI shows real-time progress: current file, downloaded/total bytes,
|
||||
and percentage. This works the same as single-node mode.
|
||||
|
||||
When an install targets more than one worker, an **N nodes** chevron
|
||||
appears on the operation row. Click it to expand a per-node breakdown,
|
||||
with one row per worker showing:
|
||||
|
||||
- A status pill: **Queued** (gray), **Downloading** (blue), **Worker busy**
|
||||
(yellow), **Done** (green), or **Failed** (red).
|
||||
- The file currently being downloaded with current/total bytes and percentage.
|
||||
- A thin per-node progress bar.
|
||||
- Any error returned by the worker.
|
||||
|
||||
The yellow **Worker busy** pill means the worker took longer than
|
||||
`--backend-install-timeout` to acknowledge but is most likely still
|
||||
working in the background. The admin UI clears it as soon as the worker
|
||||
finishes; no action is required from the operator.
|
||||
|
||||
If a worker is running an older LocalAI release that does not report
|
||||
progress, its row in the breakdown will still show terminal status
|
||||
(queued / done / failed / worker busy) but no per-file progress.
|
||||
|
||||
## Worker Configuration
|
||||
|
||||
Workers are started with the `worker` subcommand. Each worker is generic — it doesn't need a backend type at startup:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user