mirror of
https://github.com/mudler/LocalAI.git
synced 2026-07-04 21:37:02 -04:00
Compare commits
1 Commits
feat/model
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce57bf8171 |
11
Dockerfile
11
Dockerfile
@@ -171,17 +171,6 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
|
||||
ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
|
||||
; fi
|
||||
|
||||
# ROCm's bundled libdrm_amdgpu is built with a hardcoded fallback lookup path
|
||||
# for the ASIC ID table (/opt/amdgpu/share/libdrm/amdgpu.ids), which only exists
|
||||
# if AMD's full amdgpu graphics/DKMS stack is installed. This compute-only image
|
||||
# doesn't have it, so hipblas/rocBLAS log "No such file or directory" on every
|
||||
# model load and can fail to identify the GPU. Point it at the equivalent file
|
||||
# Ubuntu's libdrm-common package already ships.
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ -f /usr/share/libdrm/amdgpu.ids ] && [ ! -e /opt/amdgpu/share/libdrm/amdgpu.ids ]; then \
|
||||
mkdir -p /opt/amdgpu/share/libdrm && \
|
||||
ln -s /usr/share/libdrm/amdgpu.ids /opt/amdgpu/share/libdrm/amdgpu.ids \
|
||||
; fi
|
||||
|
||||
RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
|
||||
|
||||
# Cuda
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=bbc7de475178dd0535c16ad85f204a2529806c9d
|
||||
IK_LLAMA_VERSION?=29431b31c89e79c10f8736e8f2742485ba1713d6
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -101,13 +101,4 @@ if(LLAMA_GRPC_BUILD_TESTS)
|
||||
target_link_libraries(message_content_test PRIVATE ${_LLAMA_COMMON_TARGET})
|
||||
target_compile_features(message_content_test PRIVATE cxx_std_17)
|
||||
add_test(NAME message_content_test COMMAND message_content_test)
|
||||
|
||||
# Parent-death watcher test (parent_watch.h) — standard library only, but
|
||||
# needs a threading runtime for std::thread.
|
||||
find_package(Threads REQUIRED)
|
||||
add_executable(parent_watch_test parent_watch_test.cpp parent_watch.h)
|
||||
target_include_directories(parent_watch_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_link_libraries(parent_watch_test PRIVATE Threads::Threads)
|
||||
target_compile_features(parent_watch_test PRIVATE cxx_std_17)
|
||||
add_test(NAME parent_watch_test COMMAND parent_watch_test)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=d4cff114c0084f1fbc9b4c62717eca8fb2ae494a
|
||||
LLAMA_VERSION?=6f4f53f2b7da54fcdbbecaaa734337c337ad6176
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -75,8 +75,6 @@
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#include "parent_watch.h" // best-effort parent-death backstop (see header)
|
||||
|
||||
|
||||
using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
@@ -3444,10 +3442,6 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
// Best-effort backstop: self-terminate if the LocalAI process that spawned
|
||||
// us dies without cleaning us up (see parent_watch.h).
|
||||
llama_grpc::start_parent_death_watcher();
|
||||
|
||||
server_context ctx_server;
|
||||
BackendServiceImpl service(ctx_server);
|
||||
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
// Parent-death watcher (best-effort backstop) for the llama.cpp gRPC backend.
|
||||
//
|
||||
// LocalAI spawns this backend as a child process and, on a clean shutdown,
|
||||
// tears it down itself (SIGTERM -> grace -> SIGKILL). That graceful path only
|
||||
// runs when LocalAI receives a catchable signal and lives long enough to run
|
||||
// its handlers. If LocalAI is SIGKILLed (e.g. a supervising process's grace
|
||||
// period elapses first), that teardown never runs and this backend would be
|
||||
// reparented to init and linger, holding VRAM and its listen port.
|
||||
//
|
||||
// The watcher here is a best-effort backstop for exactly that case: it does
|
||||
// NOT replace the graceful teardown, it only covers the "parent vanished
|
||||
// without cleaning up" path. It detects reparenting: when the process that
|
||||
// spawned this backend dies, the kernel reparents us to the nearest sub-reaper
|
||||
// or to init (PID 1), so getppid() stops matching the value captured at
|
||||
// startup. This getppid() approach is portable across Linux/macOS (unlike the
|
||||
// Linux-only PR_SET_PDEATHSIG), which is why it is used here, mirroring the Go
|
||||
// backends' pkg/grpc/parentwatch.go. It is disabled on Windows, which has no
|
||||
// equivalent orphan-reparenting semantics.
|
||||
//
|
||||
// This header is intentionally dependency-free (C++ standard library only) so
|
||||
// it can be exercised by a standalone unit test (parent_watch_test.cpp) without
|
||||
// building the full llama.cpp + gRPC backend.
|
||||
#ifndef LLAMA_GRPC_PARENT_WATCH_H
|
||||
#define LLAMA_GRPC_PARENT_WATCH_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#if !defined(_WIN32)
|
||||
#include <unistd.h> // getppid(2), _exit(2)
|
||||
#endif
|
||||
|
||||
namespace llama_grpc {
|
||||
|
||||
// Env var names are shared verbatim with the Go and Python backends for
|
||||
// consistency across languages.
|
||||
inline const char *kEnvParentWatch() { return "LOCALAI_BACKEND_PARENT_WATCH"; }
|
||||
inline const char *kEnvParentWatchInterval() { return "LOCALAI_BACKEND_PARENT_WATCH_INTERVAL"; }
|
||||
|
||||
// Default poll interval in milliseconds. Matches the Go side's 2 * time.Second.
|
||||
inline long parent_watch_default_interval_ms() { return 2000; }
|
||||
|
||||
namespace detail {
|
||||
inline std::string trim_lower(const std::string &in, bool lower) {
|
||||
size_t a = in.find_first_not_of(" \t\r\n");
|
||||
size_t b = in.find_last_not_of(" \t\r\n");
|
||||
if (a == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
std::string s = in.substr(a, b - a + 1);
|
||||
if (lower) {
|
||||
std::transform(s.begin(), s.end(), s.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
}
|
||||
return s;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// parent_watch_enabled reports whether the watcher should run. Enabled by
|
||||
// default; a falsey value ("false"/"0"/"no"/"off", case-insensitive) disables
|
||||
// it, matching the Go implementation's exact semantics.
|
||||
inline bool parent_watch_enabled() {
|
||||
#if defined(_WIN32)
|
||||
return false;
|
||||
#else
|
||||
const char *v = std::getenv(kEnvParentWatch());
|
||||
if (v == nullptr || v[0] == '\0') {
|
||||
return true;
|
||||
}
|
||||
const std::string s = detail::trim_lower(v, true);
|
||||
return !(s == "false" || s == "0" || s == "no" || s == "off");
|
||||
#endif
|
||||
}
|
||||
|
||||
// parent_watch_interval_ms returns the poll interval in milliseconds. Accepts
|
||||
// Go-style duration strings ("500ms", "2s", "1m") for cross-language parity, or
|
||||
// a bare number interpreted as seconds. Defaults to
|
||||
// parent_watch_default_interval_ms().
|
||||
inline long parent_watch_interval_ms() {
|
||||
const long def = parent_watch_default_interval_ms();
|
||||
const char *v = std::getenv(kEnvParentWatchInterval());
|
||||
if (v == nullptr || v[0] == '\0') {
|
||||
return def;
|
||||
}
|
||||
const std::string s = detail::trim_lower(v, false);
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
size_t i = 0;
|
||||
while (i < s.size() && (std::isdigit((unsigned char)s[i]) || s[i] == '.')) {
|
||||
i++;
|
||||
}
|
||||
if (i == 0) {
|
||||
return def;
|
||||
}
|
||||
double num = 0.0;
|
||||
try {
|
||||
num = std::stod(s.substr(0, i));
|
||||
} catch (...) {
|
||||
return def;
|
||||
}
|
||||
const std::string unit = s.substr(i);
|
||||
long ms;
|
||||
if (unit == "ms") {
|
||||
ms = (long)num;
|
||||
} else if (unit == "s" || unit.empty()) {
|
||||
ms = (long)(num * 1000.0);
|
||||
} else if (unit == "m") {
|
||||
ms = (long)(num * 60000.0);
|
||||
} else {
|
||||
return def; // unrecognized unit
|
||||
}
|
||||
return ms > 0 ? ms : def;
|
||||
}
|
||||
|
||||
#if !defined(_WIN32)
|
||||
// parent_died reports whether this process has been reparented away from the
|
||||
// parent it had when the watcher started. Reparenting is the standard POSIX
|
||||
// signal that the original parent (here, the LocalAI process that spawned this
|
||||
// backend) has exited: the orphan is handed to the nearest sub-reaper or to
|
||||
// init (PID 1), so getppid() no longer matches the value captured at startup.
|
||||
inline bool parent_died(pid_t orig_ppid) {
|
||||
const pid_t ppid = getppid();
|
||||
return ppid != orig_ppid || ppid == 1;
|
||||
}
|
||||
|
||||
// watch_parent_death polls until parent_died reports the original parent is
|
||||
// gone, then invokes on_death. It blocks, so run it on its own thread.
|
||||
inline void watch_parent_death(pid_t orig_ppid, long interval_ms,
|
||||
const std::function<void()> &on_death) {
|
||||
for (;;) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
|
||||
if (parent_died(orig_ppid)) {
|
||||
on_death();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// start_parent_death_watcher installs the best-effort safety net described in
|
||||
// the file header on the calling backend process. It is a no-op when disabled,
|
||||
// on Windows, or when the process is already orphaned at startup
|
||||
// (getppid() <= 1). This is a backstop alongside — never a replacement for —
|
||||
// LocalAI's graceful teardown.
|
||||
inline void start_parent_death_watcher() {
|
||||
#if !defined(_WIN32)
|
||||
if (!parent_watch_enabled()) {
|
||||
return;
|
||||
}
|
||||
const pid_t orig_ppid = getppid();
|
||||
// A parent of 1 (or less) at startup means we were already orphaned (or
|
||||
// launched directly under init) — there is no original parent to watch for.
|
||||
if (orig_ppid <= 1) {
|
||||
return;
|
||||
}
|
||||
const long interval_ms = parent_watch_interval_ms();
|
||||
std::thread([orig_ppid, interval_ms]() {
|
||||
watch_parent_death(orig_ppid, interval_ms, [orig_ppid]() {
|
||||
fprintf(stderr,
|
||||
"backend parent process (pid %d) exited without stopping "
|
||||
"this backend; self-terminating to avoid orphaning\n",
|
||||
(int)orig_ppid);
|
||||
fflush(stderr);
|
||||
_exit(1);
|
||||
});
|
||||
}).detach();
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace llama_grpc
|
||||
|
||||
#endif // LLAMA_GRPC_PARENT_WATCH_H
|
||||
@@ -1,197 +0,0 @@
|
||||
// Unit tests for the parent-death watcher (parent_watch.h).
|
||||
//
|
||||
// Build & run standalone (C++ standard library only, no nlohmann/json needed):
|
||||
// g++ -std=c++17 -pthread parent_watch_test.cpp -o t && ./t
|
||||
//
|
||||
// The core test (TestDetectsReparent) builds a genuine two-level process tree
|
||||
// (test -> middle -> grandchild), lets the middle process die, and asserts the
|
||||
// grandchild's watch_parent_death detects the reparenting and self-terminates —
|
||||
// mirroring the Go test in pkg/grpc/parentwatch_test.go, but with fork(2).
|
||||
//
|
||||
// On Windows this file compiles to a no-op success (the watcher is unsupported
|
||||
// there), matching parent_watch.h's platform gating.
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
#include "parent_watch.h"
|
||||
|
||||
static int failures = 0;
|
||||
|
||||
static void check(bool ok, const std::string &name) {
|
||||
if (!ok) {
|
||||
failures++;
|
||||
fprintf(stderr, "FAIL: %s\n", name.c_str());
|
||||
} else {
|
||||
fprintf(stderr, "ok: %s\n", name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Env-parsing tests are platform-independent and always run.
|
||||
static void test_env_parsing() {
|
||||
using namespace llama_grpc;
|
||||
|
||||
// Interval: default when unset.
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL");
|
||||
check(parent_watch_interval_ms() == 2000, "interval default 2000ms");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "500ms", 1);
|
||||
check(parent_watch_interval_ms() == 500, "interval 500ms");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "2s", 1);
|
||||
check(parent_watch_interval_ms() == 2000, "interval 2s");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "1m", 1);
|
||||
check(parent_watch_interval_ms() == 60000, "interval 1m");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "3", 1); // bare number -> seconds
|
||||
check(parent_watch_interval_ms() == 3000, "interval bare 3 -> 3000ms");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "garbage", 1);
|
||||
check(parent_watch_interval_ms() == 2000, "interval garbage -> default");
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL");
|
||||
|
||||
#if !defined(_WIN32)
|
||||
// Enabled semantics (POSIX only; always false on Windows).
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH");
|
||||
check(parent_watch_enabled(), "enabled by default");
|
||||
|
||||
for (const char *falsey : {"false", "0", "no", "off", "OFF", " False "}) {
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH", falsey, 1);
|
||||
check(!parent_watch_enabled(), std::string("disabled by '") + falsey + "'");
|
||||
}
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH", "true", 1);
|
||||
check(parent_watch_enabled(), "enabled by 'true'");
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH", "1", 1);
|
||||
check(parent_watch_enabled(), "enabled by '1'");
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH");
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(_WIN32)
|
||||
|
||||
#include <atomic>
|
||||
#include <ctime>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/wait.h>
|
||||
#include <unistd.h>
|
||||
|
||||
static bool file_exists(const std::string &p) {
|
||||
struct stat st;
|
||||
return ::stat(p.c_str(), &st) == 0;
|
||||
}
|
||||
|
||||
static bool wait_for_file(const std::string &p, int timeout_ms) {
|
||||
int waited = 0;
|
||||
while (waited < timeout_ms) {
|
||||
if (file_exists(p)) {
|
||||
return true;
|
||||
}
|
||||
usleep(20 * 1000);
|
||||
waited += 20;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void write_file(const std::string &p, const std::string &content) {
|
||||
FILE *f = fopen(p.c_str(), "w");
|
||||
if (f) {
|
||||
fwrite(content.data(), 1, content.size(), f);
|
||||
fclose(f);
|
||||
}
|
||||
}
|
||||
|
||||
// Builds test -> middle -> grandchild via fork(2). The grandchild arms the REAL
|
||||
// watch_parent_death against middle; middle exits, orphaning the grandchild;
|
||||
// the watcher must detect the reparenting and self-terminate.
|
||||
static void test_detects_reparent() {
|
||||
char tmpl[] = "/tmp/parentwatch_test_XXXXXX";
|
||||
char *dir = mkdtemp(tmpl);
|
||||
if (dir == nullptr) {
|
||||
check(false, "mkdtemp");
|
||||
return;
|
||||
}
|
||||
const std::string ready_file = std::string(dir) + "/ready";
|
||||
const std::string exited_file = std::string(dir) + "/exited";
|
||||
|
||||
pid_t middle = fork();
|
||||
if (middle < 0) {
|
||||
check(false, "fork middle");
|
||||
return;
|
||||
}
|
||||
|
||||
if (middle == 0) {
|
||||
// ---- middle process ----
|
||||
pid_t grandchild = fork();
|
||||
if (grandchild < 0) {
|
||||
_exit(4);
|
||||
}
|
||||
if (grandchild == 0) {
|
||||
// ---- grandchild process ----
|
||||
pid_t orig_ppid = getppid(); // == middle
|
||||
std::thread([&]() {
|
||||
llama_grpc::watch_parent_death(orig_ppid, 50 /*ms*/, [&]() {
|
||||
write_file(exited_file, "1");
|
||||
_exit(7);
|
||||
});
|
||||
}).detach();
|
||||
|
||||
// Safety valve: never linger if something goes wrong.
|
||||
std::thread([]() {
|
||||
usleep(30 * 1000 * 1000);
|
||||
_exit(2);
|
||||
}).detach();
|
||||
|
||||
// Signal readiness only after the watcher captured orig_ppid.
|
||||
write_file(ready_file, std::to_string(getpid()));
|
||||
for (;;) {
|
||||
pause();
|
||||
}
|
||||
}
|
||||
// middle: wait until grandchild is ready, then exit to orphan it.
|
||||
if (!wait_for_file(ready_file, 10000)) {
|
||||
_exit(5);
|
||||
}
|
||||
_exit(0);
|
||||
}
|
||||
|
||||
// ---- test (top) process ----
|
||||
int status = 0;
|
||||
waitpid(middle, &status, 0); // reap middle only; grandchild is orphaned
|
||||
|
||||
check(file_exists(ready_file), "grandchild signaled readiness");
|
||||
|
||||
bool detected = wait_for_file(exited_file, 10000);
|
||||
check(detected, "watcher detected parent death and self-terminated");
|
||||
|
||||
// Best-effort cleanup: kill the grandchild if it somehow survived.
|
||||
if (file_exists(ready_file)) {
|
||||
FILE *f = fopen(ready_file.c_str(), "r");
|
||||
if (f) {
|
||||
int pid = 0;
|
||||
if (fscanf(f, "%d", &pid) == 1 && pid > 1) {
|
||||
kill(pid, SIGKILL);
|
||||
}
|
||||
fclose(f);
|
||||
}
|
||||
}
|
||||
unlink(ready_file.c_str());
|
||||
unlink(exited_file.c_str());
|
||||
rmdir(dir);
|
||||
}
|
||||
|
||||
#endif // !_WIN32
|
||||
|
||||
int main() {
|
||||
test_env_parsing();
|
||||
#if !defined(_WIN32)
|
||||
test_detects_reparent();
|
||||
#endif
|
||||
if (failures == 0) {
|
||||
fprintf(stderr, "\nAll parent_watch tests passed.\n");
|
||||
return 0;
|
||||
}
|
||||
fprintf(stderr, "\n%d parent_watch test(s) failed.\n", failures);
|
||||
return 1;
|
||||
}
|
||||
@@ -22,10 +22,6 @@ cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
|
||||
# unit test (compiled only when -DLLAMA_GRPC_BUILD_TESTS=ON).
|
||||
cp -r message_content.h llama.cpp/tools/grpc-server/
|
||||
cp -r message_content_test.cpp llama.cpp/tools/grpc-server/
|
||||
# Parent-death watcher (included by grpc-server.cpp) and its standalone unit
|
||||
# test (run via backend/cpp/run-unit-tests.sh; also buildable under ctest).
|
||||
cp -r parent_watch.h llama.cpp/tools/grpc-server/
|
||||
cp -r parent_watch_test.cpp llama.cpp/tools/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/
|
||||
|
||||
|
||||
@@ -36,12 +36,6 @@ else
|
||||
if [ -d "$CURDIR/lib/rocblas/library" ]; then
|
||||
export ROCBLAS_TENSILE_LIBPATH="$CURDIR"/lib/rocblas/library
|
||||
fi
|
||||
# Same for hipBLASLt (rocblaslt): the bundled libhipblaslt.so resolves its
|
||||
# TensileLibrary_lazy_gfx*.dat kernel data relative to itself, so point it at
|
||||
# the bundled data or it falls back to slow generic kernels (issue #10660).
|
||||
if [ -d "$CURDIR/lib/hipblaslt/library" ]; then
|
||||
export HIPBLASLT_TENSILE_LIBPATH="$CURDIR"/lib/hipblaslt/library
|
||||
fi
|
||||
fi
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# Local development: point at a working checkout instead of cloning, e.g.
|
||||
# make PRIVACY_FILTER_SRC=$HOME/c/privacy-filter.cpp grpc-server
|
||||
|
||||
PRIVACY_FILTER_VERSION?=735a6c28607ee82afc3a670383f41b55266a3b9a
|
||||
PRIVACY_FILTER_VERSION?=595f59630c69d361b5196f2aba2c71c873d0c13c
|
||||
PRIVACY_FILTER_REPO?=https://github.com/localai-org/privacy-filter.cpp
|
||||
PRIVACY_FILTER_SRC?=
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ for test_src in "${tests[@]}"; do
|
||||
name="$(basename "$test_src" .cpp)"
|
||||
bin="$(mktemp -d)/$name"
|
||||
echo "==> $test_src"
|
||||
if ! "$CXX" -std=c++17 -Wall -Wextra -pthread \
|
||||
if ! "$CXX" -std=c++17 -Wall -Wextra \
|
||||
-I"$JSON_INC" -I"$(dirname "$test_src")" \
|
||||
"$test_src" -o "$bin"; then
|
||||
echo "COMPILE FAILED: $test_src" >&2
|
||||
|
||||
@@ -34,12 +34,6 @@ else
|
||||
if [ -d "$CURDIR/lib/rocblas/library" ]; then
|
||||
export ROCBLAS_TENSILE_LIBPATH="$CURDIR"/lib/rocblas/library
|
||||
fi
|
||||
# Same for hipBLASLt (rocblaslt): the bundled libhipblaslt.so resolves its
|
||||
# TensileLibrary_lazy_gfx*.dat kernel data relative to itself, so point it at
|
||||
# the bundled data or it falls back to slow generic kernels (issue #10660).
|
||||
if [ -d "$CURDIR/lib/hipblaslt/library" ]; then
|
||||
export HIPBLASLT_TENSILE_LIBPATH="$CURDIR"/lib/hipblaslt/library
|
||||
fi
|
||||
fi
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
|
||||
@@ -25,7 +25,7 @@ target_include_directories(goacestepcpp PRIVATE ${ACESTEP_DIR}/src ${ACESTEP_DIR
|
||||
target_include_directories(goacestepcpp SYSTEM PRIVATE ${ACESTEP_DIR}/ggml/include)
|
||||
|
||||
# Link GPU backends if available (mirrors link_ggml_backends macro)
|
||||
foreach(backend blas cuda hip metal vulkan)
|
||||
foreach(backend blas cuda metal vulkan)
|
||||
if(TARGET ggml-${backend})
|
||||
target_link_libraries(goacestepcpp PRIVATE ggml-${backend})
|
||||
string(TOUPPER ${backend} BACKEND_UPPER)
|
||||
|
||||
@@ -24,14 +24,7 @@ else ifeq ($(BUILD_TYPE),openblas)
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
# This ggml only understands GGML_HIP (GGML_HIPBLAS was removed upstream),
|
||||
# so passing GGML_HIPBLAS silently produced a CPU-only build (see #10666).
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS ?= gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -142,12 +142,19 @@ func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream boo
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Do not forward temperature/top_p. Newer Anthropic reasoning models reject
|
||||
// requests that carry temperature ("`temperature` is deprecated for this
|
||||
// model"), and the OpenAI-compatible clients typically send only the
|
||||
// server-side DEFAULT sampling values rather than user intent — dropping
|
||||
// them loses nothing and lets the upstream apply its own defaults.
|
||||
_ = opts
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -74,16 +75,15 @@ func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
// Newer Anthropic reasoning models reject requests carrying temperature
|
||||
// ("`temperature` is deprecated for this model"); clients typically send
|
||||
// only default sampling values, so the translator forwards neither.
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// Sampling parameters are not forwarded at all — the upstream applies its
|
||||
// own defaults (newest models reject explicit temperature/top_p).
|
||||
// When only top_p is set, it should be forwarded.
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
@@ -99,7 +99,11 @@ func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
|
||||
@@ -30,7 +30,7 @@ type openAIRequest struct {
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_completion_tokens,omitempty"` // newer OpenAI models reject max_tokens ("use max_completion_tokens instead")
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
@@ -107,10 +107,14 @@ func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool)
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
// Do not forward temperature/top_p. Newer OpenAI reasoning models reject
|
||||
// temperature as deprecated, and clients typically send only default
|
||||
// sampling values rather than user intent — let the upstream apply its
|
||||
// own defaults.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
|
||||
@@ -74,9 +74,8 @@ func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
// Sampling parameters are not forwarded (newest models reject explicit
|
||||
// temperature); token limit is serialized as max_completion_tokens.
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=f35185b876fc482fcb2053a81a2697936ed5fcc0
|
||||
CRISPASR_VERSION?=3b93758f9725d400eca82976f895e4cec3f31260
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -30,7 +30,7 @@ target_include_directories(gomnivoicecpp PRIVATE ${OMNIVOICE_DIR}/src)
|
||||
target_include_directories(gomnivoicecpp SYSTEM PRIVATE ${OMNIVOICE_DIR}/ggml/include)
|
||||
|
||||
# Link GPU backends if the upstream ggml created them.
|
||||
foreach(backend blas cuda hip metal vulkan sycl)
|
||||
foreach(backend blas cuda metal vulkan sycl)
|
||||
if(TARGET ggml-${backend})
|
||||
target_link_libraries(gomnivoicecpp PRIVATE ggml-${backend})
|
||||
if(backend STREQUAL "cuda")
|
||||
|
||||
@@ -24,14 +24,7 @@ else ifeq ($(BUILD_TYPE),openblas)
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
# This ggml only understands GGML_HIP (GGML_HIPBLAS was removed upstream),
|
||||
# so passing GGML_HIPBLAS silently produced a CPU-only build (see #10666).
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS ?= gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=e8acc6172a94e20a952cf1843decace5d771a94b
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=f469a57270a1cc4554acb15febf60e56619673b9
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=e8acc6172a94e20a952cf1843decace5d771a94b
|
||||
PARAKEET_VERSION?=f469a57270a1cc4554acb15febf60e56619673b9
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -30,7 +30,7 @@ target_include_directories(goqwen3ttscpp PRIVATE ${QWENTTS_DIR}/src)
|
||||
target_include_directories(goqwen3ttscpp SYSTEM PRIVATE ${QWENTTS_DIR}/ggml/include)
|
||||
|
||||
# Link GPU backends if the upstream ggml created them.
|
||||
foreach(backend blas cuda hip metal vulkan sycl)
|
||||
foreach(backend blas cuda metal vulkan sycl)
|
||||
if(TARGET ggml-${backend})
|
||||
target_link_libraries(goqwen3ttscpp PRIVATE ggml-${backend})
|
||||
if(backend STREQUAL "cuda")
|
||||
|
||||
@@ -24,14 +24,7 @@ else ifeq ($(BUILD_TYPE),openblas)
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
# This ggml only understands GGML_HIP (GGML_HIPBLAS was removed upstream),
|
||||
# so passing GGML_HIPBLAS silently produced a CPU-only build (see #10666).
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS ?= gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=2574f5936571645f784b77623e1f09bad97d948a
|
||||
STABLEDIFFUSION_GGML_VERSION?=484baa41e5e006c52dcd4addc38c830b9489745f
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ target_include_directories(govibevoicecpp SYSTEM PRIVATE ${VIBEVOICE_DIR}/third_
|
||||
# Link GPU backends if available — vibevoice's own CMake already links
|
||||
# these to the libvibevoice STATIC library, but we re-link them on the
|
||||
# MODULE so resolved symbols include all backend kernels.
|
||||
foreach(backend blas cuda hip metal vulkan)
|
||||
foreach(backend blas cuda metal vulkan)
|
||||
if(TARGET ggml-${backend})
|
||||
target_link_libraries(govibevoicecpp PRIVATE ggml-${backend})
|
||||
string(TOUPPER ${backend} BACKEND_UPPER)
|
||||
|
||||
@@ -29,14 +29,7 @@ else ifeq ($(BUILD_TYPE),openblas)
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
# This ggml only understands GGML_HIP (GGML_HIPBLAS was removed upstream),
|
||||
# so passing GGML_HIPBLAS silently produced a CPU-only build (see #10666).
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS ?= gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DVIBEVOICE_GGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DVIBEVOICE_GGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=6fc7c33b4c3a2cec83e4b65abd5e96a890480375
|
||||
WHISPER_CPP_VERSION?=0874de3e8e8e48361dba85c7fe6d176f008bf158
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -11,8 +11,6 @@ import os
|
||||
|
||||
import grpc
|
||||
|
||||
from parent_watch import start_parent_death_watcher
|
||||
|
||||
|
||||
class _AbortHandler(grpc.RpcMethodHandler):
|
||||
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||
@@ -72,13 +70,6 @@ def get_auth_interceptors(*, aio: bool = False):
|
||||
|
||||
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||
"""
|
||||
# Arm the best-effort parent-death backstop here: this is the single helper
|
||||
# every LocalAI Python backend invokes exactly once while building its gRPC
|
||||
# server (mirroring how the Go watcher arms in pkg/grpc's shared serve path).
|
||||
# start_parent_death_watcher() is idempotent and a no-op when disabled or on
|
||||
# unsupported platforms — see parent_watch.py.
|
||||
start_parent_death_watcher()
|
||||
|
||||
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||
if not token:
|
||||
return []
|
||||
|
||||
@@ -20,15 +20,7 @@ def split_reasoning(text, think_start, think_end):
|
||||
Returns ``(reasoning_content, remaining_text)``. When ``think_start`` is
|
||||
empty or not found, returns ``("", text)`` unchanged.
|
||||
"""
|
||||
if not think_start or not text:
|
||||
return "", text
|
||||
if think_start not in text:
|
||||
# Models like Qwen3.5 open assistant turns already INSIDE thinking, so
|
||||
# the generated text carries only the closing tag. Everything before it
|
||||
# is reasoning that would otherwise leak into the content.
|
||||
if think_end and think_end in text:
|
||||
head, _, tail = text.partition(think_end)
|
||||
return head.strip(), tail.strip()
|
||||
if not think_start or not text or think_start not in text:
|
||||
return "", text
|
||||
pattern = re.compile(
|
||||
re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
"""Unit tests for the mlx/mlx-vlm shared helpers (mlx_utils.py).
|
||||
|
||||
Run standalone (Python standard library only, no backend venv needed):
|
||||
python3 -m unittest mlx_utils_test
|
||||
|
||||
These mirror the server-less helper tests in backend/python/mlx/test.py
|
||||
(TestSharedHelpers), but live here so they run on any platform: the mlx
|
||||
test module imports grpc/backend_pb2 at import time and needs the MLX venv,
|
||||
whereas mlx_utils only needs the standard library.
|
||||
"""
|
||||
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
|
||||
class TestSplitReasoning(unittest.TestCase):
|
||||
def test_both_tags(self):
|
||||
r, c = split_reasoning(
|
||||
"<think>step 1\nstep 2</think>The answer is 42.", "<think>", "</think>"
|
||||
)
|
||||
self.assertEqual(r, "step 1\nstep 2")
|
||||
self.assertEqual(c, "The answer is 42.")
|
||||
|
||||
def test_implicit_opener_only_closing_tag(self):
|
||||
# Qwen3.5 opens the assistant turn already inside thinking, so the
|
||||
# output carries only the closing tag; everything before it is reasoning.
|
||||
r, c = split_reasoning(
|
||||
"The user is asking about the weather.\n</think>\n\nThe weather in Rome is sunny.",
|
||||
"<think>",
|
||||
"</think>",
|
||||
)
|
||||
self.assertEqual(r, "The user is asking about the weather.")
|
||||
self.assertEqual(c, "The weather in Rome is sunny.")
|
||||
|
||||
def test_no_tags_at_all(self):
|
||||
r, c = split_reasoning("just text", "<think>", "</think>")
|
||||
self.assertEqual(r, "")
|
||||
self.assertEqual(c, "just text")
|
||||
|
||||
def test_empty_think_end_and_no_opener_match(self):
|
||||
# No think_end to anchor on, and the opener is absent → return unchanged.
|
||||
r, c = split_reasoning("no opener here", "<think>", "")
|
||||
self.assertEqual(r, "")
|
||||
self.assertEqual(c, "no opener here")
|
||||
|
||||
def test_empty_text(self):
|
||||
r, c = split_reasoning("", "<think>", "</think>")
|
||||
self.assertEqual(r, "")
|
||||
self.assertEqual(c, "")
|
||||
|
||||
|
||||
class TestParseToolCalls(unittest.TestCase):
|
||||
def test_with_shim(self):
|
||||
tm = types.SimpleNamespace(
|
||||
tool_call_start="<tool_call>",
|
||||
tool_call_end="</tool_call>",
|
||||
parse_tool_call=lambda body, tools: {
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": body.strip()},
|
||||
},
|
||||
)
|
||||
calls, remaining = parse_tool_calls(
|
||||
"Sure: <tool_call>Paris</tool_call>", tm, tools=None
|
||||
)
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0]["name"], "get_weather")
|
||||
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
|
||||
self.assertEqual(calls[0]["index"], 0)
|
||||
self.assertNotIn("<tool_call>", remaining)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Parent-death watcher (best-effort backstop) for LocalAI Python backends.
|
||||
|
||||
LocalAI spawns each backend as a child process and, on a clean shutdown, tears
|
||||
it down itself (SIGTERM -> grace -> SIGKILL). That graceful path only runs when
|
||||
LocalAI receives a catchable signal and lives long enough to run its handlers.
|
||||
If LocalAI is SIGKILLed (e.g. a supervising process's grace period elapses
|
||||
first), that teardown never runs and this backend would be reparented to init
|
||||
and linger, holding GPU/VRAM and its listen port.
|
||||
|
||||
The watcher here is a best-effort backstop for exactly that case: it does NOT
|
||||
replace the graceful teardown, it only covers the "parent vanished without
|
||||
cleaning up" path. It detects reparenting: when the process that spawned this
|
||||
backend dies, the kernel reparents us to the nearest sub-reaper or to init
|
||||
(PID 1), so os.getppid() stops matching the value captured at startup. This
|
||||
getppid() approach is portable across Linux/macOS (unlike the Linux-only
|
||||
PR_SET_PDEATHSIG), which is why it is used here, mirroring the Go backends'
|
||||
pkg/grpc/parentwatch.go and the C++ backends' parent_watch.h. It is disabled on
|
||||
Windows, which has no equivalent orphan-reparenting semantics.
|
||||
|
||||
Env vars (shared verbatim across the Go, C++ and Python backends):
|
||||
LOCALAI_BACKEND_PARENT_WATCH enabled by default; a falsey value
|
||||
("false"/"0"/"no"/"off", case-insensitive)
|
||||
disables it.
|
||||
LOCALAI_BACKEND_PARENT_WATCH_INTERVAL poll interval as a Go-style duration
|
||||
string ("500ms", "2s", "1m") or a bare
|
||||
number of seconds. Defaults to 2s.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
|
||||
ENV_PARENT_WATCH = "LOCALAI_BACKEND_PARENT_WATCH"
|
||||
ENV_PARENT_WATCH_INTERVAL = "LOCALAI_BACKEND_PARENT_WATCH_INTERVAL"
|
||||
|
||||
_DEFAULT_INTERVAL_SECONDS = 2.0
|
||||
|
||||
# Guard so repeated calls (e.g. get_auth_interceptors invoked more than once)
|
||||
# only ever arm a single watcher thread per process.
|
||||
_started = False
|
||||
_started_lock = threading.Lock()
|
||||
|
||||
|
||||
def _enabled():
|
||||
"""Report whether the watcher should run in this process."""
|
||||
# Windows does not reparent orphans to a well-known init PID, so the
|
||||
# getppid() heuristic used here doesn't apply there.
|
||||
if os.name == "nt" or sys.platform.startswith("win"):
|
||||
return False
|
||||
val = os.environ.get(ENV_PARENT_WATCH, "").strip().lower()
|
||||
if val in ("false", "0", "no", "off"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _interval_seconds():
|
||||
"""Return the configured poll interval in seconds, or the default.
|
||||
|
||||
Accepts Go-style duration strings ("500ms", "2s", "1m") for cross-language
|
||||
parity, or a bare number interpreted as seconds.
|
||||
"""
|
||||
raw = os.environ.get(ENV_PARENT_WATCH_INTERVAL, "").strip()
|
||||
if not raw:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
# Split numeric prefix from unit suffix.
|
||||
i = 0
|
||||
while i < len(raw) and (raw[i].isdigit() or raw[i] == "." or (i == 0 and raw[i] in "+-")):
|
||||
i += 1
|
||||
if i == 0:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
try:
|
||||
num = float(raw[:i])
|
||||
except ValueError:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
unit = raw[i:].lower()
|
||||
if unit == "ms":
|
||||
seconds = num / 1000.0
|
||||
elif unit in ("s", ""):
|
||||
seconds = num
|
||||
elif unit == "m":
|
||||
seconds = num * 60.0
|
||||
else:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
return seconds if seconds > 0 else _DEFAULT_INTERVAL_SECONDS
|
||||
|
||||
|
||||
def _parent_died(orig_ppid):
|
||||
"""Report whether this process has been reparented away from orig_ppid.
|
||||
|
||||
Reparenting is the standard POSIX signal that the original parent (here, the
|
||||
LocalAI process that spawned this backend) has exited: the orphan is handed
|
||||
to the nearest sub-reaper or to init (PID 1), so os.getppid() no longer
|
||||
matches the value captured at startup.
|
||||
"""
|
||||
ppid = os.getppid()
|
||||
return ppid != orig_ppid or ppid == 1
|
||||
|
||||
|
||||
def _watch(orig_ppid, interval, on_death):
|
||||
"""Poll until _parent_died reports the original parent is gone, then call
|
||||
on_death. Blocks, so run it on its own (daemon) thread."""
|
||||
import time
|
||||
|
||||
while True:
|
||||
time.sleep(interval)
|
||||
if _parent_died(orig_ppid):
|
||||
on_death()
|
||||
return
|
||||
|
||||
|
||||
def start_parent_death_watcher():
|
||||
"""Install the best-effort safety net described in this module's docstring.
|
||||
|
||||
No-op when disabled, on Windows, when already orphaned at startup
|
||||
(os.getppid() <= 1), or if already started. This is a backstop alongside —
|
||||
never a replacement for — LocalAI's graceful teardown.
|
||||
"""
|
||||
global _started
|
||||
if not _enabled():
|
||||
return
|
||||
with _started_lock:
|
||||
if _started:
|
||||
return
|
||||
orig_ppid = os.getppid()
|
||||
# A parent of 1 (or less) at startup means we were already orphaned (or
|
||||
# launched directly under init) — there is no original parent to watch.
|
||||
if orig_ppid <= 1:
|
||||
return
|
||||
interval = _interval_seconds()
|
||||
|
||||
def on_death():
|
||||
print(
|
||||
"backend parent process (pid {}) exited without stopping this "
|
||||
"backend; self-terminating to avoid orphaning".format(orig_ppid),
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
# Immediate, non-cleanup exit: this is a shutdown safety net and the
|
||||
# normal graceful path is already gone.
|
||||
os._exit(1)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=_watch,
|
||||
args=(orig_ppid, interval, on_death),
|
||||
name="parent-death-watcher",
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
_started = True
|
||||
@@ -1,150 +0,0 @@
|
||||
"""Unit tests for the parent-death watcher (parent_watch.py).
|
||||
|
||||
Run standalone (Python standard library only, no backend venv needed):
|
||||
python3 -m unittest parent_watch_test
|
||||
|
||||
The core test (test_detects_reparent) builds a genuine two-level process tree
|
||||
(test -> middle -> grandchild) with os.fork, lets the middle process die, and
|
||||
asserts the grandchild's parent_watch._watch detects the reparenting and
|
||||
self-terminates — mirroring the Go test in pkg/grpc/parentwatch_test.go and the
|
||||
C++ test in backend/cpp/llama-cpp/parent_watch_test.cpp.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import parent_watch
|
||||
|
||||
|
||||
class TestParentWatchEnvParsing(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._saved = {
|
||||
k: os.environ.get(k)
|
||||
for k in (parent_watch.ENV_PARENT_WATCH, parent_watch.ENV_PARENT_WATCH_INTERVAL)
|
||||
}
|
||||
for k in self._saved:
|
||||
os.environ.pop(k, None)
|
||||
|
||||
def tearDown(self):
|
||||
for k, v in self._saved.items():
|
||||
if v is None:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v
|
||||
|
||||
def test_interval_default(self):
|
||||
self.assertEqual(parent_watch._interval_seconds(), 2.0)
|
||||
|
||||
def test_interval_units(self):
|
||||
cases = {"500ms": 0.5, "2s": 2.0, "1m": 60.0, "3": 3.0, "0.5s": 0.5}
|
||||
for raw, expected in cases.items():
|
||||
os.environ[parent_watch.ENV_PARENT_WATCH_INTERVAL] = raw
|
||||
self.assertAlmostEqual(parent_watch._interval_seconds(), expected, msg=raw)
|
||||
|
||||
def test_interval_garbage_falls_back(self):
|
||||
os.environ[parent_watch.ENV_PARENT_WATCH_INTERVAL] = "garbage"
|
||||
self.assertEqual(parent_watch._interval_seconds(), 2.0)
|
||||
|
||||
@unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "POSIX only")
|
||||
def test_enabled_default(self):
|
||||
self.assertTrue(parent_watch._enabled())
|
||||
|
||||
@unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "POSIX only")
|
||||
def test_disabled_by_falsey(self):
|
||||
for val in ("false", "0", "no", "off", "OFF", " False "):
|
||||
os.environ[parent_watch.ENV_PARENT_WATCH] = val
|
||||
self.assertFalse(parent_watch._enabled(), msg=val)
|
||||
|
||||
@unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "POSIX only")
|
||||
def test_enabled_by_truthy(self):
|
||||
for val in ("true", "1", "yes", "on"):
|
||||
os.environ[parent_watch.ENV_PARENT_WATCH] = val
|
||||
self.assertTrue(parent_watch._enabled(), msg=val)
|
||||
|
||||
|
||||
@unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "fork/reparent is POSIX only")
|
||||
class TestParentWatchReparent(unittest.TestCase):
|
||||
def _wait_for_file(self, path, timeout=10.0):
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
if os.path.exists(path):
|
||||
return True
|
||||
time.sleep(0.02)
|
||||
return False
|
||||
|
||||
def test_detects_reparent(self):
|
||||
tmpdir = tempfile.mkdtemp(prefix="parentwatch_test_")
|
||||
ready_file = os.path.join(tmpdir, "ready")
|
||||
exited_file = os.path.join(tmpdir, "exited")
|
||||
|
||||
middle = os.fork()
|
||||
if middle == 0:
|
||||
# ---- middle process ----
|
||||
grandchild = os.fork()
|
||||
if grandchild == 0:
|
||||
# ---- grandchild process: arm the REAL watcher against middle ----
|
||||
orig_ppid = os.getppid()
|
||||
|
||||
def on_death():
|
||||
with open(exited_file, "w") as f:
|
||||
f.write("1")
|
||||
os._exit(7)
|
||||
|
||||
threading.Thread(
|
||||
target=parent_watch._watch,
|
||||
args=(orig_ppid, 0.05, on_death),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
# Safety valve: never linger if something goes wrong.
|
||||
def bail():
|
||||
time.sleep(30)
|
||||
os._exit(2)
|
||||
|
||||
threading.Thread(target=bail, daemon=True).start()
|
||||
|
||||
# Signal readiness only after the watcher captured orig_ppid.
|
||||
with open(ready_file, "w") as f:
|
||||
f.write(str(os.getpid()))
|
||||
while True:
|
||||
time.sleep(1)
|
||||
else:
|
||||
# middle: wait until grandchild is ready, then exit to orphan it.
|
||||
if not self._wait_for_file(ready_file):
|
||||
os._exit(5)
|
||||
os._exit(0)
|
||||
|
||||
# ---- test (top) process ----
|
||||
os.waitpid(middle, 0) # reap middle only; grandchild is orphaned
|
||||
|
||||
self.assertTrue(os.path.exists(ready_file), "grandchild never signaled readiness")
|
||||
self.assertTrue(
|
||||
self._wait_for_file(exited_file),
|
||||
"watcher did not detect parent death within timeout",
|
||||
)
|
||||
|
||||
# Best-effort cleanup: kill the grandchild if it somehow survived.
|
||||
try:
|
||||
with open(ready_file) as f:
|
||||
pid = int(f.read().strip())
|
||||
if pid > 1:
|
||||
os.kill(pid, 9)
|
||||
except (OSError, ValueError):
|
||||
pass
|
||||
for p in (ready_file, exited_file):
|
||||
try:
|
||||
os.remove(p)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
os.rmdir(tmpdir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -58,18 +58,7 @@ def messages_to_dicts(proto_messages):
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
tool_calls = json.loads(msg.tool_calls)
|
||||
# Chat templates (e.g. Qwen) iterate function.arguments as a
|
||||
# mapping, but the OpenAI wire format carries it as a JSON
|
||||
# string — decode it back so the template's .items() works.
|
||||
for tc in tool_calls:
|
||||
fn = tc.get("function") if isinstance(tc, dict) else None
|
||||
if isinstance(fn, dict) and isinstance(fn.get("arguments"), str):
|
||||
try:
|
||||
fn["arguments"] = json.loads(fn["arguments"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
d["tool_calls"] = tool_calls
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append(d)
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Unit tests for the shared python backend helpers (python_utils.py).
|
||||
|
||||
Run standalone (Python standard library only, no backend venv needed):
|
||||
python3 -m unittest python_utils_test
|
||||
|
||||
These mirror the server-less helper tests in backend/python/mlx/test.py
|
||||
(TestSharedHelpers), but live here so they run on any platform: the mlx
|
||||
test module imports grpc/backend_pb2 at import time and needs the MLX venv,
|
||||
whereas python_utils has no third-party dependency. Proto Message objects
|
||||
are faked with types.SimpleNamespace (real proto fields default to "").
|
||||
"""
|
||||
|
||||
import json
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
|
||||
|
||||
def _msg(**fields):
|
||||
"""Fake a proto Message: every unset field is the empty string, as protobuf."""
|
||||
defaults = {
|
||||
"role": "",
|
||||
"content": "",
|
||||
"name": "",
|
||||
"tool_call_id": "",
|
||||
"reasoning_content": "",
|
||||
"tool_calls": "",
|
||||
}
|
||||
defaults.update(fields)
|
||||
return types.SimpleNamespace(**defaults)
|
||||
|
||||
|
||||
class TestParseOptions(unittest.TestCase):
|
||||
def test_type_inference(self):
|
||||
opts = parse_options(
|
||||
["temperature:0.7", "max_tokens:128", "trust:true", "name:hello", "no_colon_skipped"]
|
||||
)
|
||||
self.assertEqual(opts["temperature"], 0.7)
|
||||
self.assertEqual(opts["max_tokens"], 128)
|
||||
self.assertIs(opts["trust"], True)
|
||||
self.assertEqual(opts["name"], "hello")
|
||||
self.assertNotIn("no_colon_skipped", opts)
|
||||
|
||||
|
||||
class TestMessagesToDicts(unittest.TestCase):
|
||||
def test_basic_fields(self):
|
||||
out = messages_to_dicts(
|
||||
[
|
||||
_msg(role="user", content="hi"),
|
||||
_msg(role="tool", content="42", tool_call_id="call_1", name="f"),
|
||||
]
|
||||
)
|
||||
self.assertEqual(out[0], {"role": "user", "content": "hi"})
|
||||
self.assertEqual(out[1]["tool_call_id"], "call_1")
|
||||
self.assertEqual(out[1]["name"], "f")
|
||||
|
||||
def test_tool_call_arguments_string_decoded_to_mapping(self):
|
||||
# OpenAI wire format ships function.arguments as a JSON *string*; chat
|
||||
# templates iterate it as a mapping, so it must come back as a dict.
|
||||
out = messages_to_dicts(
|
||||
[
|
||||
_msg(
|
||||
role="assistant",
|
||||
tool_calls=json.dumps(
|
||||
[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "Rome"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
args = out[0]["tool_calls"][0]["function"]["arguments"]
|
||||
self.assertEqual(args, {"location": "Rome"})
|
||||
self.assertEqual(dict(args.items()), {"location": "Rome"})
|
||||
|
||||
def test_tool_call_arguments_already_mapping_is_idempotent(self):
|
||||
out = messages_to_dicts(
|
||||
[
|
||||
_msg(
|
||||
role="assistant",
|
||||
tool_calls=json.dumps(
|
||||
[{"function": {"name": "f", "arguments": {"a": 1}}}]
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
self.assertEqual(out[0]["tool_calls"][0]["function"]["arguments"], {"a": 1})
|
||||
|
||||
def test_tool_call_arguments_invalid_json_left_as_string(self):
|
||||
out = messages_to_dicts(
|
||||
[
|
||||
_msg(
|
||||
role="assistant",
|
||||
tool_calls=json.dumps(
|
||||
[{"function": {"name": "f", "arguments": "not-json"}}]
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
self.assertEqual(out[0]["tool_calls"][0]["function"]["arguments"], "not-json")
|
||||
|
||||
def test_tool_call_without_function_key(self):
|
||||
out = messages_to_dicts(
|
||||
[_msg(role="assistant", tool_calls=json.dumps([{"id": "call_1"}]))]
|
||||
)
|
||||
self.assertEqual(out[0]["tool_calls"], [{"id": "call_1"}])
|
||||
|
||||
def test_tool_calls_invalid_json_dropped(self):
|
||||
out = messages_to_dicts([_msg(role="assistant", tool_calls="{not json")])
|
||||
self.assertNotIn("tool_calls", out[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -748,12 +748,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# When (A) native streaming ran cleanly, per-delta yields above already
|
||||
# delivered everything — do NOT extract again on the full text or we'd
|
||||
# duplicate content/tool_calls into the final chunk.
|
||||
# NOTE: `native_streaming` is a capability flag ("streaming parser is
|
||||
# available"), not a state flag ("streaming actually ran"). For
|
||||
# non-streaming requests it is still True but the per-delta loop was
|
||||
# never entered, so we MUST still run extract_tool_calls here. Hence
|
||||
# the explicit `streaming and …` guard on both branches.
|
||||
if has_tool_parser and not (streaming and native_streaming and not native_streaming_error):
|
||||
if has_tool_parser and not (native_streaming and not native_streaming_error):
|
||||
try:
|
||||
tp = tp_instance
|
||||
if tp is None:
|
||||
@@ -775,7 +770,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"Tool parser error: {e}", file=sys.stderr)
|
||||
elif streaming and native_streaming and not native_streaming_error:
|
||||
elif native_streaming and not native_streaming_error:
|
||||
# Per-delta path already emitted content + tool_calls; the final
|
||||
# chat_delta should carry only metadata (token counts, logprobs).
|
||||
content = ""
|
||||
|
||||
@@ -35,21 +35,6 @@ if [ "x${BUILD_PROFILE}" == "xcpu" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# AMD ROCm: vLLM ships prebuilt ROCm wheels, but on a DEDICATED index
|
||||
# (https://wheels.vllm.ai/rocm/), NOT PyPI, and ONLY for CPython 3.12. On any
|
||||
# other Python the installer silently falls back to the CUDA-only PyPI wheel,
|
||||
# which is unusable on an AMD GPU (import fails, so the backend never finds the
|
||||
# vllm module). Force Python 3.12 before the venv is created (matches the
|
||||
# intel/l4t13 cp312 bump); the hipblas branch below pulls vllm from the ROCm
|
||||
# wheel index. unsafe-best-match lets uv consult that index and PyPI together.
|
||||
# https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html?device=rocm
|
||||
if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# cublas13 pulls the vLLM wheel from a per-tag cu130 index (PyPI's vllm wheel
|
||||
# is built against CUDA 12 and won't load on cu130). uv's default per-package
|
||||
# first-match strategy would still pick the PyPI wheel, so allow it to consult
|
||||
@@ -119,7 +104,7 @@ if [ "$(uname -s)" = "Darwin" ]; then
|
||||
# can rewrite it. Darwin therefore follows vllm-metal and can lag the Linux
|
||||
# vllm pin (requirements-cublas13-after.txt, bumped independently against
|
||||
# vllm/vllm) until vllm-metal supports a newer vLLM.
|
||||
VLLM_METAL_VERSION="v0.3.0.dev20260701212152"
|
||||
VLLM_METAL_VERSION="v0.3.0.dev20260628073537"
|
||||
|
||||
# The coupled vLLM source version is whatever this vllm-metal release builds
|
||||
# against -- it declares it in its own installer as `vllm_v=`. Derive it from
|
||||
@@ -209,22 +194,6 @@ elif [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# AMD ROCm: install vllm from its dedicated ROCm wheel index instead of the
|
||||
# CUDA-only PyPI wheel. installRequirements brings the base ROCm
|
||||
# torch/transformers (requirements-hipblas.txt), then we pull vllm (plus the
|
||||
# matching ROCm torch, via --upgrade) from wheels.vllm.ai/rocm. This is the
|
||||
# method upstream prescribes for AMD; the Python-3.12 pin is set above.
|
||||
# There is intentionally no requirements-hipblas-after.txt: a bare `vllm`
|
||||
# there would resolve to the CUDA wheel, and installRequirements never loads
|
||||
# a ${BUILD_TYPE}-after file for hipblas anyway (BUILD_TYPE == BUILD_PROFILE).
|
||||
# https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html?device=rocm
|
||||
elif [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
installRequirements
|
||||
|
||||
# --upgrade reconciles the base ROCm torch to whatever the vllm ROCm wheel
|
||||
# pins; --extra-index-url adds the ROCm wheel repository on top of PyPI.
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} \
|
||||
--extra-index-url https://wheels.vllm.ai/rocm/ --upgrade vllm
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
accelerate
|
||||
torch==2.9.1+cpu
|
||||
torch==2.12.0+cpu
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
|
||||
1
backend/python/vllm/requirements-hipblas-after.txt
Normal file
1
backend/python/vllm/requirements-hipblas-after.txt
Normal file
@@ -0,0 +1 @@
|
||||
vllm
|
||||
@@ -356,12 +356,6 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
SharedModels: cfg.Distributed.SharedModels,
|
||||
// Cap how long a cold load may hold the per-model advisory lock: the
|
||||
// configured backend.install deadline plus a margin for file staging and
|
||||
// the remote LoadModel. Derived from the install timeout so raising it
|
||||
// (for slow links pulling multi-GB images) widens the ceiling too,
|
||||
// instead of letting the static default cut a legitimately slow load.
|
||||
ModelLoadCeiling: cfg.Distributed.BackendInstallTimeoutOrDefault() + 10*time.Minute,
|
||||
})
|
||||
|
||||
// Wire staging-progress broadcasting so file-staging shows up on every
|
||||
|
||||
@@ -369,7 +369,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", false, options.RequireBackendIntegrity); err != nil {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", options.RequireBackendIntegrity); err != nil {
|
||||
xlog.Error("error installing external backend", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -473,13 +473,20 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
|
||||
if options.LoadToMemory != nil && !options.SingleBackend {
|
||||
for _, m := range options.LoadToMemory {
|
||||
xlog.Debug("Auto loading model into memory from file", "model", m)
|
||||
// Same path as POST /backend/load: a realtime pipeline model expands
|
||||
// to its sub-models, and load failures are recorded as model_load
|
||||
// traces.
|
||||
if _, err := backend.PreloadModelByName(options.Context, application.ModelConfigLoader(), application.ModelLoader(), options, m); err != nil {
|
||||
cfg, err := application.ModelConfigLoader().LoadModelConfigFileByNameDefaultOptions(m, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
xlog.Debug("Auto loading model into memory from file", "model", m, "file", cfg.Model)
|
||||
|
||||
o := backend.ModelOptions(*cfg, options)
|
||||
|
||||
var backendErr error
|
||||
_, backendErr = application.ModelLoader().Load(o...)
|
||||
if backendErr != nil {
|
||||
return nil, backendErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -52,22 +52,6 @@ func ModelLoadTraceObserver(appConfig *config.ApplicationConfig) func(model.Back
|
||||
}
|
||||
}
|
||||
|
||||
// PreloadModel warms a model into memory without running any inference, so the
|
||||
// first real request doesn't pay the backend's cold-start load cost. It uses
|
||||
// the same ModelOptions + ml.Load path the modality functions use, so a
|
||||
// subsequent inference call hits the loader cache instead of reloading. Load
|
||||
// failures are recorded and returned; callers that warm models opportunistically
|
||||
// (e.g. realtime session warm-up) typically log and continue, since the lazy
|
||||
// path will retry on first use.
|
||||
func PreloadModel(ctx context.Context, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) error {
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
if _, err := ml.Load(opts...); err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordModelLoadFailure records a backend trace when model loading fails.
|
||||
func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) {
|
||||
if !appConfig.EnableTracing {
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// PreloadModelByName loads the named model into memory so the first request
|
||||
// that uses it pays no cold-start load cost — the inverse of shutting a model
|
||||
// down. If the model is a realtime pipeline (its config declares a `pipeline:`
|
||||
// block), each configured sub-model (VAD, transcription, LLM, TTS,
|
||||
// sound_detection, voice_recognition) is loaded concurrently instead of the
|
||||
// pipeline stub, which has no backend of its own. It returns the model names
|
||||
// actually loaded and a joined error naming each sub-model that failed (nil on
|
||||
// full success); a partial pipeline load reports both the loaded names and the
|
||||
// failures so the caller can surface exactly what is and isn't resident.
|
||||
// Compaction's summary_model is deliberately left cold: it is only invoked off
|
||||
// the response path, so it can stay lazy.
|
||||
func PreloadModelByName(ctx context.Context, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, name string) ([]string, error) {
|
||||
cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(name, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stages, err := pipelineStages(cl, &cfg.Pipeline, ml.ModelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(stages) == 0 {
|
||||
// Not a pipeline: load the model's own backend directly.
|
||||
if err := PreloadModel(ctx, ml, *cfg, appConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{cfg.Name}, nil
|
||||
}
|
||||
return PreloadStages(ctx, ml, appConfig, stages)
|
||||
}
|
||||
|
||||
// PreloadStage names one pipeline sub-model to preload and the resolved config
|
||||
// to load it from (nil = stage absent, skipped). Role labels the pipeline slot
|
||||
// in errors and logs.
|
||||
type PreloadStage struct {
|
||||
Role string
|
||||
Cfg *config.ModelConfig
|
||||
}
|
||||
|
||||
// loadStage is PreloadModel behind a seam so PreloadStages can be unit-tested
|
||||
// without spawning real backends.
|
||||
var loadStage = PreloadModel
|
||||
|
||||
// pipelineStages resolves each populated pipeline stage to its concrete model
|
||||
// config, following a single alias hop — the same resolution the realtime
|
||||
// pipeline itself uses. A stage that fails to resolve is a misconfiguration,
|
||||
// so it fails fast rather than being deferred to load. A pipeline with no
|
||||
// stages set returns nil, which callers treat as "not a pipeline".
|
||||
func pipelineStages(cl *config.ModelConfigLoader, p *config.Pipeline, modelPath string) ([]PreloadStage, error) {
|
||||
voiceRec := ""
|
||||
if p.VoiceRecognition != nil {
|
||||
voiceRec = p.VoiceRecognition.Model
|
||||
}
|
||||
var stages []PreloadStage
|
||||
for _, s := range []struct{ role, name string }{
|
||||
{"vad", p.VAD},
|
||||
{"transcription", p.Transcription},
|
||||
{"llm", p.LLM},
|
||||
{"tts", p.TTS},
|
||||
{"sound_detection", p.SoundDetection},
|
||||
{"voice_recognition", voiceRec},
|
||||
} {
|
||||
if s.name == "" {
|
||||
continue
|
||||
}
|
||||
cfg, err := cl.LoadResolvedModelConfig(s.name, modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s (%s): %w", s.role, s.name, err)
|
||||
}
|
||||
stages = append(stages, PreloadStage{Role: s.role, Cfg: cfg})
|
||||
}
|
||||
return stages, nil
|
||||
}
|
||||
|
||||
// PreloadStages loads every present stage at once and waits for all of them, so
|
||||
// a pipeline warms in the time of its slowest stage rather than the sum. Absent
|
||||
// (nil-config) stages are skipped. A failed stage does not cancel the others —
|
||||
// they all run to completion so the joined error names every broken stage at
|
||||
// once, alongside the names that did load.
|
||||
func PreloadStages(ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, stages []PreloadStage) ([]string, error) {
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
loaded []string
|
||||
errs []error
|
||||
)
|
||||
for _, s := range stages {
|
||||
if s.Cfg == nil {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(s PreloadStage) {
|
||||
defer wg.Done()
|
||||
if err := loadStage(ctx, ml, *s.Cfg, appConfig); err != nil {
|
||||
xlog.Warn("preload: failed to load pipeline sub-model", "stage", s.Role, "model", s.Cfg.Name, "error", err)
|
||||
mu.Lock()
|
||||
errs = append(errs, fmt.Errorf("%s (%s): %w", s.Role, s.Cfg.Name, err))
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
xlog.Debug("preload: loaded pipeline sub-model", "stage", s.Role, "model", s.Cfg.Name)
|
||||
mu.Lock()
|
||||
loaded = append(loaded, s.Cfg.Name)
|
||||
mu.Unlock()
|
||||
}(s)
|
||||
}
|
||||
wg.Wait()
|
||||
return loaded, errors.Join(errs...)
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("pipelineStages", func() {
|
||||
seed := func(dir string, names ...string) *config.ModelConfigLoader {
|
||||
for _, n := range names {
|
||||
yaml := "name: " + n + "\nbackend: fake-backend\n"
|
||||
Expect(os.WriteFile(filepath.Join(dir, n+".yaml"), []byte(yaml), 0o644)).To(Succeed())
|
||||
}
|
||||
cl := config.NewModelConfigLoader(dir)
|
||||
Expect(cl.LoadModelConfigsFromPath(dir)).To(Succeed())
|
||||
return cl
|
||||
}
|
||||
|
||||
It("resolves only the populated stages, in load order", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
cl := seed(dir, "vad-m", "stt-m", "llm-m", "tts-m")
|
||||
|
||||
stages, err := pipelineStages(cl, &config.Pipeline{
|
||||
VAD: "vad-m",
|
||||
Transcription: "stt-m",
|
||||
LLM: "llm-m",
|
||||
TTS: "tts-m",
|
||||
}, dir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
roles := make([]string, len(stages))
|
||||
names := make([]string, len(stages))
|
||||
for i, s := range stages {
|
||||
roles[i] = s.Role
|
||||
names[i] = s.Cfg.Name
|
||||
}
|
||||
Expect(roles).To(Equal([]string{"vad", "transcription", "llm", "tts"}))
|
||||
Expect(names).To(Equal([]string{"vad-m", "stt-m", "llm-m", "tts-m"}))
|
||||
})
|
||||
|
||||
It("skips unset stages and includes sound_detection and voice_recognition when set", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
cl := seed(dir, "stt-m", "ced", "spk")
|
||||
|
||||
stages, err := pipelineStages(cl, &config.Pipeline{
|
||||
Transcription: "stt-m",
|
||||
SoundDetection: "ced",
|
||||
VoiceRecognition: &config.PipelineVoiceRecognition{Model: "spk"},
|
||||
}, dir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
roles := make([]string, len(stages))
|
||||
for i, s := range stages {
|
||||
roles[i] = s.Role
|
||||
}
|
||||
Expect(roles).To(ConsistOf("transcription", "sound_detection", "voice_recognition"))
|
||||
})
|
||||
|
||||
It("returns nil for a pipeline with no stages (not a pipeline)", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
cl := seed(dir)
|
||||
|
||||
stages, err := pipelineStages(cl, &config.Pipeline{}, dir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stages).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("PreloadStages", func() {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seen []string
|
||||
)
|
||||
|
||||
// stubLoader swaps the loadStage seam for a recorder so no real backends
|
||||
// are spawned; errFor injects per-model failures.
|
||||
stubLoader := func(errFor map[string]error) {
|
||||
loadStage = func(_ context.Context, _ *model.ModelLoader, cfg config.ModelConfig, _ *config.ApplicationConfig) error {
|
||||
mu.Lock()
|
||||
seen = append(seen, cfg.Name)
|
||||
mu.Unlock()
|
||||
return errFor[cfg.Name]
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
seen = nil
|
||||
})
|
||||
AfterEach(func() {
|
||||
loadStage = PreloadModel
|
||||
})
|
||||
|
||||
mkStage := func(role, name string) PreloadStage {
|
||||
return PreloadStage{Role: role, Cfg: &config.ModelConfig{Name: name}}
|
||||
}
|
||||
|
||||
It("loads every present stage, skips absent (nil-config) ones, and returns the loaded names", func() {
|
||||
stubLoader(nil)
|
||||
|
||||
loaded, err := PreloadStages(context.Background(), nil, nil, []PreloadStage{
|
||||
mkStage("vad", "vad-m"),
|
||||
{Role: "transcription"}, // absent stage
|
||||
mkStage("llm", "llm-m"),
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(loaded).To(ConsistOf("vad-m", "llm-m"))
|
||||
// Barrier: every stage has run by the time PreloadStages returns, so
|
||||
// reading seen without the lock here is safe.
|
||||
Expect(seen).To(ConsistOf("vad-m", "llm-m"))
|
||||
})
|
||||
|
||||
It("reports a joined error naming each failed stage while still loading the rest", func() {
|
||||
stubLoader(map[string]error{
|
||||
"vad-m": errors.New("vad boom"),
|
||||
"tts-m": errors.New("tts boom"),
|
||||
})
|
||||
|
||||
loaded, err := PreloadStages(context.Background(), nil, nil, []PreloadStage{
|
||||
mkStage("vad", "vad-m"),
|
||||
mkStage("llm", "llm-m"),
|
||||
mkStage("tts", "tts-m"),
|
||||
})
|
||||
|
||||
// Every stage ran (a failure does not cancel the others)...
|
||||
Expect(seen).To(ConsistOf("vad-m", "llm-m", "tts-m"))
|
||||
// ...the stage that loaded fine is reported as loaded...
|
||||
Expect(loaded).To(ConsistOf("llm-m"))
|
||||
// ...and the joined error names every broken stage and its cause.
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("vad (vad-m)"))
|
||||
Expect(err.Error()).To(ContainSubstring("vad boom"))
|
||||
Expect(err.Error()).To(ContainSubstring("tts (tts-m)"))
|
||||
Expect(err.Error()).To(ContainSubstring("tts boom"))
|
||||
Expect(err.Error()).ToNot(ContainSubstring("llm"))
|
||||
})
|
||||
})
|
||||
@@ -127,7 +127,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias, false, bi.RequireBackendIntegrity)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias, bi.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -67,6 +67,16 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
ApplyMTPDefaults(cfg, n)
|
||||
}
|
||||
|
||||
// Sliding-window-attention models (Gemma 2/3, Cohere2, Llama 4, ...) ship
|
||||
// with a reduced SWA KV cache by default, which cannot reuse a prompt
|
||||
// prefix across requests and so defeats the cross-request prefix cache
|
||||
// (cache_reuse) we enable in serving_defaults.go. Enable the full SWA cache
|
||||
// for these models so the prefix survives; skipped for dense models and
|
||||
// when the user already pinned an SWA cache option.
|
||||
if w, ok := HasSlidingWindowAttention(f); ok {
|
||||
ApplySWAFullDefault(cfg, w)
|
||||
}
|
||||
|
||||
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
||||
|
||||
// template estimations
|
||||
|
||||
@@ -599,13 +599,6 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Component: "toggle",
|
||||
Order: 89,
|
||||
},
|
||||
"pipeline.disable_warmup": {
|
||||
Section: "pipeline",
|
||||
Label: "Disable Warmup",
|
||||
Description: "Turn off eager pre-loading of the pipeline's sub-models at realtime session start. By default LocalAI loads every configured sub-model backend (VAD, transcription, LLM, TTS, sound detection, voice recognition) before the session starts and blocks until they are ready, so the first turn pays no cold-start cost and a model that fails to load is reported at session start instead of mid-call. Enable this to restore the lazy 'load on first use' behavior — session start no longer waits on loading and load errors surface on the first turn instead. Useful to keep idle sessions from holding model memory they may never use.",
|
||||
Component: "toggle",
|
||||
Order: 90,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
package config
|
||||
|
||||
// This file is the single source of truth for deriving a model's user-facing
|
||||
// capabilities and input/output modalities from its ModelConfig. Both the
|
||||
// OpenAI-compatible /v1/models/capabilities endpoint and the Ollama-compatible
|
||||
// /api/tags|/api/show surface consume these, so the vocabulary stays consistent
|
||||
// across clients. Keep the detection heuristics here rather than duplicating
|
||||
// them per endpoint.
|
||||
|
||||
// VisionSupported reports whether the model can accept image inputs.
|
||||
//
|
||||
// We deliberately avoid HasUsecases(FLAG_VISION): GuessUsecases has no
|
||||
// FLAG_VISION branch and reports true for any chat model, so it would paint
|
||||
// vision onto text-only models. Instead we look for explicit signals: the
|
||||
// declared KnownUsecases bit, a multimodal projector, or a template/backend
|
||||
// multimodal marker.
|
||||
func (c *ModelConfig) VisionSupported() bool {
|
||||
if c.KnownUsecases != nil && (*c.KnownUsecases&FLAG_VISION) == FLAG_VISION {
|
||||
return true
|
||||
}
|
||||
if c.MMProj != "" {
|
||||
return true
|
||||
}
|
||||
if c.TemplateConfig.Multimodal != "" {
|
||||
return true
|
||||
}
|
||||
if c.MediaMarker != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ToolSupported reports whether the model is wired up for tool / function
|
||||
// calling. We look for any of the explicit knobs LocalAI uses to drive
|
||||
// function-call extraction (regex match, response regex, grammar triggers, XML
|
||||
// format) or the auto-detected tool-format markers the llama.cpp backend
|
||||
// populates during model load.
|
||||
func (c *ModelConfig) ToolSupported() bool {
|
||||
fc := c.FunctionsConfig
|
||||
if fc.ToolFormatMarkers != nil && fc.ToolFormatMarkers.FormatType != "" {
|
||||
return true
|
||||
}
|
||||
if len(fc.JSONRegexMatch) > 0 || len(fc.ResponseRegex) > 0 {
|
||||
return true
|
||||
}
|
||||
if fc.XMLFormatPreset != "" || fc.XMLFormat != nil {
|
||||
return true
|
||||
}
|
||||
if len(fc.GrammarConfig.GrammarTriggers) > 0 || fc.GrammarConfig.SchemaType != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ThinkingSupported reports whether the model has reasoning / thinking enabled.
|
||||
// LocalAI sets DisableReasoning=false (or leaves thinking markers configured)
|
||||
// when the backend probe reports that the model supports thinking.
|
||||
func (c *ModelConfig) ThinkingSupported() bool {
|
||||
rc := c.ReasoningConfig
|
||||
if rc.DisableReasoning != nil && !*rc.DisableReasoning {
|
||||
return true
|
||||
}
|
||||
if len(rc.ThinkingStartTokens) > 0 || len(rc.TagPairs) > 0 {
|
||||
// Explicit thinking markers imply support unless explicitly disabled.
|
||||
return rc.DisableReasoning == nil || !*rc.DisableReasoning
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AudioInputSupported reports whether a chat/generation model accepts audio as
|
||||
// input (e.g. vLLM omni models). The signal is the vLLM per-prompt audio limit;
|
||||
// there is no FLAG_* for "chat model that hears audio", which is exactly why a
|
||||
// plain usecase list can't express it. Transcription models are handled
|
||||
// separately in InputModalities via FLAG_TRANSCRIPT.
|
||||
func (c *ModelConfig) AudioInputSupported() bool {
|
||||
return c.LimitMMPerPrompt.LimitAudioPerPrompt > 0
|
||||
}
|
||||
|
||||
// VideoInputSupported reports whether a chat/generation model accepts video as
|
||||
// input. The signal is the vLLM per-prompt video limit. Note this is distinct
|
||||
// from FLAG_VIDEO, which denotes video *generation* (diffusers) — an output
|
||||
// modality, not an input one.
|
||||
func (c *ModelConfig) VideoInputSupported() bool {
|
||||
return c.LimitMMPerPrompt.LimitVideoPerPrompt > 0
|
||||
}
|
||||
|
||||
// Capabilities returns the ordered list of capability strings the model
|
||||
// supports, using the canonical usecase vocabulary (chat, vision, transcript,
|
||||
// tts, embeddings, image, video, ...) plus the modifier capabilities "tools"
|
||||
// and "thinking". Vision is resolved via VisionSupported (not HasUsecases) to
|
||||
// avoid the guess-heuristic false positive.
|
||||
func (c *ModelConfig) Capabilities() []string {
|
||||
chat := c.HasUsecases(FLAG_CHAT)
|
||||
completion := c.HasUsecases(FLAG_COMPLETION)
|
||||
|
||||
var caps []string
|
||||
add := func(cond bool, name string) {
|
||||
if cond {
|
||||
caps = append(caps, name)
|
||||
}
|
||||
}
|
||||
|
||||
add(chat, UsecaseChat)
|
||||
add(completion, UsecaseCompletion)
|
||||
add(c.HasUsecases(FLAG_EDIT), UsecaseEdit)
|
||||
add(c.HasUsecases(FLAG_EMBEDDINGS), UsecaseEmbeddings)
|
||||
add(c.HasUsecases(FLAG_RERANK), UsecaseRerank)
|
||||
// Vision is only meaningful as an image-understanding modifier on a chat/
|
||||
// completion model. Gating on (chat||completion) matches the Ollama surface
|
||||
// and avoids a false positive when config defaults hydrate a MediaMarker on
|
||||
// a non-chat model (e.g. a pure ASR/TTS backend).
|
||||
add((chat || completion) && c.VisionSupported(), UsecaseVision)
|
||||
// tools/thinking are modifiers on the chat/completion surface.
|
||||
add((chat || completion) && c.ToolSupported(), "tools")
|
||||
add((chat || completion) && c.ThinkingSupported(), "thinking")
|
||||
add(c.HasUsecases(FLAG_TRANSCRIPT), UsecaseTranscript)
|
||||
add(c.HasUsecases(FLAG_TTS), UsecaseTTS)
|
||||
add(c.HasUsecases(FLAG_SOUND_GENERATION), UsecaseSoundGeneration)
|
||||
add(c.HasUsecases(FLAG_IMAGE), UsecaseImage)
|
||||
add(c.HasUsecases(FLAG_VIDEO), UsecaseVideo)
|
||||
add(c.HasUsecases(FLAG_VAD), UsecaseVAD)
|
||||
add(c.HasUsecases(FLAG_DETECTION), UsecaseDetection)
|
||||
add(c.HasUsecases(FLAG_DEPTH), UsecaseDepth)
|
||||
add(c.HasUsecases(FLAG_AUDIO_TRANSFORM), UsecaseAudioTransform)
|
||||
add(c.HasUsecases(FLAG_DIARIZATION), UsecaseDiarization)
|
||||
add(c.HasUsecases(FLAG_SOUND_CLASSIFICATION), UsecaseSoundClassification)
|
||||
add(c.HasUsecases(FLAG_REALTIME_AUDIO), UsecaseRealtimeAudio)
|
||||
add(c.HasUsecases(FLAG_FACE_RECOGNITION), UsecaseFaceRecognition)
|
||||
add(c.HasUsecases(FLAG_SPEAKER_RECOGNITION), UsecaseSpeakerRecognition)
|
||||
return caps
|
||||
}
|
||||
|
||||
// InputModalities returns the set of modalities (text, image, audio, video) the
|
||||
// model accepts as input, ordered text→image→audio→video. This is what an
|
||||
// attachment router consults to decide whether an image/audio/video file can be
|
||||
// handed to the active model directly.
|
||||
func (c *ModelConfig) InputModalities() []string {
|
||||
imageGen := c.HasUsecases(FLAG_IMAGE)
|
||||
videoGen := c.HasUsecases(FLAG_VIDEO)
|
||||
chatish := c.HasUsecases(FLAG_CHAT) || c.HasUsecases(FLAG_COMPLETION)
|
||||
|
||||
textIn := chatish || c.HasUsecases(FLAG_EDIT) ||
|
||||
c.HasUsecases(FLAG_EMBEDDINGS) || c.HasUsecases(FLAG_RERANK) || c.HasUsecases(FLAG_TOKENIZE) ||
|
||||
c.HasUsecases(FLAG_TTS) || c.HasUsecases(FLAG_SOUND_GENERATION) || imageGen || videoGen
|
||||
|
||||
// Image input via a chat model requires vision (gated on chat, like the
|
||||
// Ollama surface); detection/depth/face models consume images directly.
|
||||
imageIn := (chatish && c.VisionSupported()) || c.LimitMMPerPrompt.LimitImagePerPrompt > 0 ||
|
||||
c.HasUsecases(FLAG_DETECTION) || c.HasUsecases(FLAG_DEPTH) || c.HasUsecases(FLAG_FACE_RECOGNITION)
|
||||
|
||||
audioIn := c.AudioInputSupported() || c.HasUsecases(FLAG_TRANSCRIPT) || c.HasUsecases(FLAG_AUDIO_TRANSFORM) ||
|
||||
c.HasUsecases(FLAG_REALTIME_AUDIO) || c.HasUsecases(FLAG_VAD) || c.HasUsecases(FLAG_DIARIZATION) ||
|
||||
c.HasUsecases(FLAG_SOUND_CLASSIFICATION) || c.HasUsecases(FLAG_SPEAKER_RECOGNITION)
|
||||
|
||||
videoIn := c.VideoInputSupported()
|
||||
|
||||
var mods []string
|
||||
if textIn {
|
||||
mods = append(mods, "text")
|
||||
}
|
||||
if imageIn {
|
||||
mods = append(mods, "image")
|
||||
}
|
||||
if audioIn {
|
||||
mods = append(mods, "audio")
|
||||
}
|
||||
if videoIn {
|
||||
mods = append(mods, "video")
|
||||
}
|
||||
return mods
|
||||
}
|
||||
|
||||
// OutputModalities returns the set of modalities (text, image, audio, video)
|
||||
// the model produces, ordered text→image→audio→video.
|
||||
func (c *ModelConfig) OutputModalities() []string {
|
||||
textOut := c.HasUsecases(FLAG_CHAT) || c.HasUsecases(FLAG_COMPLETION) || c.HasUsecases(FLAG_EDIT) ||
|
||||
c.HasUsecases(FLAG_TRANSCRIPT)
|
||||
imageOut := c.HasUsecases(FLAG_IMAGE)
|
||||
audioOut := c.HasUsecases(FLAG_TTS) || c.HasUsecases(FLAG_SOUND_GENERATION) ||
|
||||
c.HasUsecases(FLAG_AUDIO_TRANSFORM) || c.HasUsecases(FLAG_REALTIME_AUDIO)
|
||||
videoOut := c.HasUsecases(FLAG_VIDEO)
|
||||
|
||||
var mods []string
|
||||
if textOut {
|
||||
mods = append(mods, "text")
|
||||
}
|
||||
if imageOut {
|
||||
mods = append(mods, "image")
|
||||
}
|
||||
if audioOut {
|
||||
mods = append(mods, "audio")
|
||||
}
|
||||
if videoOut {
|
||||
mods = append(mods, "video")
|
||||
}
|
||||
return mods
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func usecaseBits(flags ModelConfigUsecase) *ModelConfigUsecase {
|
||||
return &flags
|
||||
}
|
||||
|
||||
var _ = Describe("Model capabilities derivation", func() {
|
||||
Describe("VisionSupported", func() {
|
||||
It("is false for a plain text chat model", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT), Backend: "llama.cpp"}
|
||||
Expect(cfg.VisionSupported()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is true when the FLAG_VISION bit is declared", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT | FLAG_VISION), Backend: "llama.cpp"}
|
||||
Expect(cfg.VisionSupported()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is true when an mmproj projector is set", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT), Backend: "llama.cpp"}
|
||||
cfg.MMProj = "mmproj.gguf" // promoted field from the embedded options struct
|
||||
Expect(cfg.VisionSupported()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not fall for the GuessUsecases FLAG_VISION false positive", func() {
|
||||
// A chat model with a chat template would make HasUsecases(FLAG_VISION)
|
||||
// return true via the guess heuristic; VisionSupported must not.
|
||||
cfg := &ModelConfig{Backend: "llama.cpp"}
|
||||
cfg.TemplateConfig.Chat = "{{.Input}}"
|
||||
Expect(cfg.VisionSupported()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("AudioInputSupported / VideoInputSupported", func() {
|
||||
It("detects vLLM omni audio input via limit_mm_per_prompt", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT), Backend: "vllm"}
|
||||
cfg.LimitMMPerPrompt.LimitAudioPerPrompt = 1
|
||||
Expect(cfg.AudioInputSupported()).To(BeTrue())
|
||||
Expect(cfg.VideoInputSupported()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("detects vLLM omni video input via limit_mm_per_prompt", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT), Backend: "vllm"}
|
||||
cfg.LimitMMPerPrompt.LimitVideoPerPrompt = 2
|
||||
Expect(cfg.VideoInputSupported()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Capabilities + modalities", func() {
|
||||
It("a text-only chat model exposes chat and text-only modalities", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT), Backend: "llama.cpp"}
|
||||
Expect(cfg.Capabilities()).To(ContainElement(UsecaseChat))
|
||||
Expect(cfg.Capabilities()).NotTo(ContainElement(UsecaseVision))
|
||||
Expect(cfg.Capabilities()).NotTo(ContainElement(UsecaseTranscript))
|
||||
Expect(cfg.InputModalities()).To(Equal([]string{"text"}))
|
||||
Expect(cfg.OutputModalities()).To(Equal([]string{"text"}))
|
||||
})
|
||||
|
||||
It("a vision chat model accepts text+image input", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT | FLAG_VISION), Backend: "llama.cpp"}
|
||||
Expect(cfg.Capabilities()).To(ContainElements(UsecaseChat, UsecaseVision))
|
||||
Expect(cfg.InputModalities()).To(Equal([]string{"text", "image"}))
|
||||
Expect(cfg.OutputModalities()).To(Equal([]string{"text"}))
|
||||
})
|
||||
|
||||
It("an omni chat model accepts text+audio input without an audio capability flag", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_CHAT), Backend: "vllm"}
|
||||
cfg.LimitMMPerPrompt.LimitAudioPerPrompt = 1
|
||||
// audio-in is a modality, not a usecase string — this is exactly the
|
||||
// case a plain capability list cannot express.
|
||||
Expect(cfg.Capabilities()).To(ContainElement(UsecaseChat))
|
||||
Expect(cfg.InputModalities()).To(Equal([]string{"text", "audio"}))
|
||||
})
|
||||
|
||||
It("a transcription model reads audio and writes text", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_TRANSCRIPT), Backend: "parakeet-cpp"}
|
||||
Expect(cfg.Capabilities()).To(Equal([]string{UsecaseTranscript}))
|
||||
Expect(cfg.InputModalities()).To(Equal([]string{"audio"}))
|
||||
Expect(cfg.OutputModalities()).To(Equal([]string{"text"}))
|
||||
})
|
||||
|
||||
It("an image-generation model reads text and writes an image", func() {
|
||||
// stablediffusion-ggml is image-only; plain "stablediffusion" is also
|
||||
// in GuessUsecases' video-backend list, so it would report video too.
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_IMAGE), Backend: "stablediffusion-ggml"}
|
||||
Expect(cfg.Capabilities()).To(Equal([]string{UsecaseImage}))
|
||||
Expect(cfg.InputModalities()).To(Equal([]string{"text"}))
|
||||
Expect(cfg.OutputModalities()).To(Equal([]string{"image"}))
|
||||
})
|
||||
|
||||
It("a TTS model reads text and writes audio", func() {
|
||||
cfg := &ModelConfig{KnownUsecases: usecaseBits(FLAG_TTS), Backend: "piper"}
|
||||
Expect(cfg.Capabilities()).To(ContainElement(UsecaseTTS))
|
||||
Expect(cfg.InputModalities()).To(Equal([]string{"text"}))
|
||||
Expect(cfg.OutputModalities()).To(Equal([]string{"audio"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -656,18 +656,6 @@ type Pipeline struct {
|
||||
// to benefit. A client session.update still overrides type and eagerness
|
||||
// per session; retranscribe is server-side only. Unset keeps server_vad.
|
||||
TurnDetection PipelineTurnDetection `yaml:"turn_detection,omitempty" json:"turn_detection,omitempty"`
|
||||
|
||||
// DisableWarmup turns off eager pre-loading of the pipeline's sub-models at
|
||||
// realtime session start. By default (false) LocalAI loads every configured
|
||||
// sub-model backend (VAD, transcription, LLM, TTS, sound detection, voice
|
||||
// recognition) into memory (concurrently) before the
|
||||
// session is announced and blocks until they are ready, so the first turn
|
||||
// pays no cold-start cost and a model that fails to load surfaces as an error
|
||||
// at session start rather than mid-call. Set true to restore the lazy "load
|
||||
// on first use" behavior — session start no longer blocks on loading and
|
||||
// load errors surface on first use instead (e.g. to keep idle sessions from
|
||||
// holding model memory they may never use).
|
||||
DisableWarmup bool `yaml:"disable_warmup,omitempty" json:"disable_warmup,omitempty"`
|
||||
}
|
||||
|
||||
// PipelineCompaction configures summarize-then-drop for a realtime pipeline.
|
||||
|
||||
@@ -155,25 +155,6 @@ func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName
|
||||
ModelPath(appConfig.SystemState.Model.ModelsPath))
|
||||
}
|
||||
|
||||
// LoadResolvedModelConfig loads a model config by name and follows a single
|
||||
// alias hop, so a caller that references an alias (e.g. a pipeline with
|
||||
// `llm: default`) gets the alias target's full config (Backend, Model, ...)
|
||||
// rather than the alias stub with an empty Backend. Without this the alias
|
||||
// survives unresolved into model loading and fails downstream — notably in
|
||||
// distributed mode with "backend name is empty". Mirrors the top-level alias
|
||||
// resolution in core/http/middleware/request.go.
|
||||
func (bcl *ModelConfigLoader) LoadResolvedModelConfig(modelName, modelPath string) (*ModelConfig, error) {
|
||||
cfg, err := bcl.LoadModelConfigFileByName(modelName, modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolved, _, err := bcl.ResolveAlias(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
|
||||
func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
|
||||
56
core/config/swa.go
Normal file
56
core/config/swa.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// swaCacheOptionNames lists the backend option keys that control the
|
||||
// sliding-window-attention KV cache. If the user pinned any of these we leave
|
||||
// the SWA cache alone instead of forcing swa_full.
|
||||
var swaCacheOptionNames = []string{"swa_full", "n_swa"}
|
||||
|
||||
// HasSlidingWindowAttention reports whether the parsed GGUF describes a
|
||||
// sliding-window-attention (SWA) model — Gemma 2/3, Cohere2, Llama 4 and the
|
||||
// like. The gguf-parser library normalizes the per-architecture
|
||||
// `<arch>.attention.sliding_window` metadata key into
|
||||
// GGUFArchitecture.AttentionSlidingWindow, applying the same family-specific
|
||||
// rules llama.cpp uses (e.g. Phi-3 carries the key but does not actually run
|
||||
// SWA, and is normalized to 0). A non-zero window means the model interleaves
|
||||
// SWA layers, so the returned size is also the diagnostic value we log.
|
||||
func HasSlidingWindowAttention(f *gguf.GGUFFile) (uint64, bool) {
|
||||
if f == nil {
|
||||
return 0, false
|
||||
}
|
||||
w := f.Architecture().AttentionSlidingWindow
|
||||
return w, w > 0
|
||||
}
|
||||
|
||||
// ApplySWAFullDefault enables the full-size SWA KV cache (swa_full:true) for a
|
||||
// sliding-window model, unless the user already pinned an SWA cache option.
|
||||
//
|
||||
// Why: llama.cpp defaults to a reduced SWA KV cache sized to the sliding window
|
||||
// (memory-light), but that reduced cache cannot preserve a prompt prefix across
|
||||
// requests. So for SWA models the cross-request prefix cache we enable in
|
||||
// serving_defaults.go (cache_reuse) is silently defeated — every turn
|
||||
// reprocesses the entire prompt. Setting swa_full:true makes llama.cpp keep the
|
||||
// full KV cache so the shared prefix is actually reused.
|
||||
//
|
||||
// The tradeoff is memory: the full SWA cache scales with context_size, so this
|
||||
// is gated to models that are genuinely SWA (never applied to dense models,
|
||||
// where it would only waste memory) and never overrides an explicit user
|
||||
// choice. `slidingWindow` is the value read from the GGUF and is used only for
|
||||
// the diagnostic log line.
|
||||
func ApplySWAFullDefault(cfg *ModelConfig, slidingWindow uint64) {
|
||||
if cfg == nil || slidingWindow == 0 {
|
||||
return
|
||||
}
|
||||
if backendOptionSet(cfg.Options, swaCacheOptionNames...) {
|
||||
xlog.Debug("[swa] sliding-window model but an SWA cache option is already set; leaving user choice intact",
|
||||
"name", cfg.Name, "sliding_window", slidingWindow)
|
||||
return
|
||||
}
|
||||
cfg.Options = append(cfg.Options, "swa_full:true")
|
||||
xlog.Debug("[swa] enabling swa_full for sliding-window model so the cross-request prompt-prefix cache survives (reduced SWA cache cannot reuse a prefix across requests)",
|
||||
"name", cfg.Name, "sliding_window", slidingWindow)
|
||||
}
|
||||
120
core/config/swa_test.go
Normal file
120
core/config/swa_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// ggufWithSlidingWindow fabricates a minimal in-memory GGUF carrying the given
|
||||
// `general.architecture` and `<arch>.attention.sliding_window` so the SWA
|
||||
// detection can be exercised without a real model file. A window of 0 omits the
|
||||
// key, modelling a dense (non-SWA) model.
|
||||
func ggufWithSlidingWindow(arch string, window uint32) *gguf.GGUFFile {
|
||||
kvs := gguf.GGUFMetadataKVs{
|
||||
{
|
||||
Key: "general.architecture",
|
||||
ValueType: gguf.GGUFMetadataValueTypeString,
|
||||
Value: arch,
|
||||
},
|
||||
}
|
||||
if window > 0 {
|
||||
kvs = append(kvs, gguf.GGUFMetadataKV{
|
||||
Key: arch + ".attention.sliding_window",
|
||||
ValueType: gguf.GGUFMetadataValueTypeUint32,
|
||||
Value: window,
|
||||
})
|
||||
}
|
||||
return &gguf.GGUFFile{
|
||||
Header: gguf.GGUFHeader{MetadataKV: kvs},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("SWA full-cache auto-default", func() {
|
||||
Context("HasSlidingWindowAttention", func() {
|
||||
It("returns false on a nil GGUF file", func() {
|
||||
w, ok := HasSlidingWindowAttention(nil)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(w).To(BeZero())
|
||||
})
|
||||
|
||||
It("detects a sliding-window model (Gemma 3 style)", func() {
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("gemma3", 1024))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(w).To(Equal(uint64(1024)))
|
||||
})
|
||||
|
||||
It("detects Gemma 2 even without an explicit key (family default window)", func() {
|
||||
// gguf-parser applies llama.cpp's family rules: gemma2 defaults the
|
||||
// sliding window to 4096 when the metadata key is absent.
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("gemma2", 0))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(w).To(Equal(uint64(4096)))
|
||||
})
|
||||
|
||||
It("reports a dense model as non-SWA", func() {
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("llama", 0))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(w).To(BeZero())
|
||||
})
|
||||
|
||||
It("treats Phi-3 as non-SWA even when the key is present", func() {
|
||||
// Phi-3 carries attention.sliding_window but does not actually run
|
||||
// SWA; gguf-parser normalizes it to 0 to match llama.cpp.
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("phi3", 2048))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(w).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ApplySWAFullDefault", func() {
|
||||
It("enables swa_full for a sliding-window model when unset", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3"}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(ContainElement("swa_full:true"))
|
||||
})
|
||||
|
||||
It("is a no-op for a dense model (window 0)", func() {
|
||||
cfg := &ModelConfig{Name: "llama"}
|
||||
ApplySWAFullDefault(cfg, 0)
|
||||
Expect(cfg.Options).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("preserves an explicit swa_full:false", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3", Options: []string{"swa_full:false"}}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{"swa_full:false"}))
|
||||
})
|
||||
|
||||
It("preserves an explicit swa_full:true without duplicating it", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3", Options: []string{"swa_full:true"}}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{"swa_full:true"}))
|
||||
})
|
||||
|
||||
It("respects the n_swa alias", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3", Options: []string{"n_swa:512"}}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{"n_swa:512"}))
|
||||
})
|
||||
|
||||
It("preserves unrelated options already on the config", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "gemma3",
|
||||
Options: []string{"use_jinja:true", "cache_reuse:256"},
|
||||
}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{
|
||||
"use_jinja:true",
|
||||
"cache_reuse:256",
|
||||
"swa_full:true",
|
||||
}))
|
||||
})
|
||||
|
||||
It("tolerates a nil config", func() {
|
||||
Expect(func() { ApplySWAFullDefault(nil, 1024) }).ToNot(Panic())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -15,35 +15,14 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// validateGalleryConfigURL guards the gallery config fetch against SSRF. A
|
||||
// gallery config URL can be attacker-controlled (e.g. POST /models/apply with
|
||||
// an empty id fetches it directly), so a plain http(s) URL must not be allowed
|
||||
// to reach private, loopback, link-local or cloud-metadata addresses. Other
|
||||
// schemes (huggingface://, github:, oci://, ollama://, file://) resolve to
|
||||
// fixed public services or local files and are not a network-SSRF vector, so
|
||||
// they are left untouched.
|
||||
// See https://github.com/mudler/LocalAI/issues/10665
|
||||
func validateGalleryConfigURL(rawURL string) error {
|
||||
lower := strings.ToLower(strings.TrimSpace(rawURL))
|
||||
if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") {
|
||||
return utils.ValidateExternalURL(rawURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
|
||||
var config T
|
||||
if err := validateGalleryConfigURL(url); err != nil {
|
||||
xlog.Error("refusing to fetch gallery config", "error", err, "url", url)
|
||||
return config, err
|
||||
}
|
||||
uri := downloader.URI(url)
|
||||
err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
@@ -57,10 +36,6 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
|
||||
|
||||
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
|
||||
var config T
|
||||
if err := validateGalleryConfigURL(url); err != nil {
|
||||
xlog.Error("refusing to fetch gallery config", "error", err, "url", url)
|
||||
return config, err
|
||||
}
|
||||
uri := downloader.URI(url)
|
||||
err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
. "github.com/mudler/LocalAI/core/gallery"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -23,49 +19,4 @@ var _ = Describe("Gallery API tests", func() {
|
||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||
})
|
||||
})
|
||||
|
||||
// SSRF guard: a user-supplied gallery config URL (e.g. POST /models/apply
|
||||
// with an empty id) must not be able to reach internal network addresses.
|
||||
// See https://github.com/mudler/LocalAI/issues/10665
|
||||
Context("SSRF protection on config URLs", func() {
|
||||
var server *httptest.Server
|
||||
|
||||
BeforeEach(func() {
|
||||
// A reachable internal server that would happily serve a valid
|
||||
// gallery config. Without the SSRF guard the fetch succeeds; the
|
||||
// guard must block it before the request ever leaves the process.
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("name: internal-ssrf\nfiles: []\n"))
|
||||
}))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
It("blocks fetching a config from a loopback address", func() {
|
||||
_, err := GetGalleryConfigFromURL[ModelConfig](server.URL, "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("not allowed"))
|
||||
})
|
||||
|
||||
It("blocks fetching a config from a loopback address (context variant)", func() {
|
||||
_, err := GetGalleryConfigFromURLWithContext[ModelConfig](context.Background(), server.URL, "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("not allowed"))
|
||||
})
|
||||
|
||||
It("blocks well-known internal hostnames and metadata endpoints", func() {
|
||||
for _, u := range []string{
|
||||
"http://localhost/secret",
|
||||
"http://10.0.0.1/config.yaml",
|
||||
"http://192.168.1.1/config.yaml",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
} {
|
||||
_, err := GetGalleryConfigFromURL[ModelConfig](u, "")
|
||||
Expect(err).To(HaveOccurred(), "expected %s to be rejected", u)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -65,10 +65,6 @@ type BackendEndpointService struct {
|
||||
|
||||
type GalleryBackend struct {
|
||||
ID string `json:"id"`
|
||||
// Force reinstalls the backend even when it is already installed and
|
||||
// runnable. Off by default so apply stays idempotent for supervising
|
||||
// apps that ensure their backend on every boot.
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService, upgradeChecker UpgradeInfoProvider) BackendEndpointService {
|
||||
@@ -107,9 +103,7 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance. The op is
|
||||
// idempotent: an already-installed, runnable backend is left alone unless the
|
||||
// request sets "force": true (explicit reinstall).
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
||||
// @Summary Install backends to LocalAI.
|
||||
// @Tags backends
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
@@ -143,7 +137,6 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint(systemState *system.Syst
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
Force: input.Force,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// POST /backends/apply must be idempotent by default: supervising apps call it
|
||||
// on every boot to ensure a backend exists, and forcing a reinstall there
|
||||
// re-downloads the whole artifact each time. Reinstall stays available behind
|
||||
// the explicit force flag.
|
||||
var _ = Describe("POST /backends/apply force plumbing", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
gs *galleryop.GalleryService
|
||||
tmpDir string
|
||||
received chan galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "backends-apply-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err := system.GetSystemState(system.WithBackendPath(tmpDir))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
appConfig := &config.ApplicationConfig{SystemState: systemState}
|
||||
|
||||
// The service is deliberately not started: the test reads the op off
|
||||
// the (unbuffered) channel itself.
|
||||
gs = galleryop.NewGalleryService(appConfig, model.NewModelLoader(systemState))
|
||||
svc := CreateBackendEndpointService(nil, systemState, gs, nil)
|
||||
app.POST("/backends/apply", svc.ApplyBackendEndpoint(systemState))
|
||||
|
||||
received = make(chan galleryop.ManagementOp[gallery.GalleryBackend, any], 1)
|
||||
go func() {
|
||||
op := <-gs.BackendGalleryChannel
|
||||
received <- op
|
||||
}()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(os.RemoveAll(tmpDir)).To(Succeed())
|
||||
})
|
||||
|
||||
apply := func(body string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/backends/apply", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
It("enqueues a non-forced op by default", func() {
|
||||
rec := apply(`{"id":"llama-cpp"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var op galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
Eventually(received).Should(Receive(&op))
|
||||
Expect(op.GalleryElementName).To(Equal("llama-cpp"))
|
||||
Expect(op.Force).To(BeFalse())
|
||||
})
|
||||
|
||||
It("enqueues a forced op when the request sets force", func() {
|
||||
rec := apply(`{"id":"llama-cpp","force":true}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var op galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
Eventually(received).Should(Receive(&op))
|
||||
Expect(op.GalleryElementName).To(Equal("llama-cpp"))
|
||||
Expect(op.Force).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -1,54 +0,0 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// LoadModelEndpoint pre-loads a model into memory by name — the inverse of
|
||||
// /backend/shutdown. For a realtime pipeline model every configured sub-model
|
||||
// (VAD, transcription, LLM, TTS, sound_detection, voice_recognition) is loaded; for a regular
|
||||
// model its own backend is loaded. The call blocks until loading finishes so
|
||||
// clients can drive warm-up explicitly and learn up front whether a model
|
||||
// fails to load.
|
||||
// @Summary Pre-load a model into memory
|
||||
// @Description Loads the named model (or, for a realtime pipeline, all of its sub-models) into memory so subsequent requests pay no cold-start cost. The inverse of /backend/shutdown.
|
||||
// @Tags monitoring
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.ModelLoadRequest true "Model to load"
|
||||
// @Success 200 {object} schema.ModelLoadResponse "Model loaded"
|
||||
// @Failure 400 {object} schema.ModelLoadResponse "Missing model name"
|
||||
// @Failure 500 {object} schema.ModelLoadResponse "Load failed (Loaded lists any sub-models that did load)"
|
||||
// @Router /backend/load [post]
|
||||
func LoadModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.ModelLoadRequest)
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if input.Model == "" {
|
||||
return c.JSON(http.StatusBadRequest, schema.ModelLoadResponse{Message: "model is required"})
|
||||
}
|
||||
|
||||
loaded, err := backend.PreloadModelByName(c.Request().Context(), cl, ml, appConfig, input.Model)
|
||||
if err != nil {
|
||||
xlog.Error("failed to pre-load model", "model", input.Model, "loaded", loaded, "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, schema.ModelLoadResponse{
|
||||
Loaded: loaded,
|
||||
Message: "failed to load model: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, schema.ModelLoadResponse{
|
||||
Loaded: loaded,
|
||||
Message: "model loaded",
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("LoadModelEndpoint (/backend/load)", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
tempDir string
|
||||
configLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
)
|
||||
|
||||
post := func(body string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/backend/load", bytes.NewBufferString(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
decode := func(rec *httptest.ResponseRecorder) schema.ModelLoadResponse {
|
||||
var resp schema.ModelLoadResponse
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
return resp
|
||||
}
|
||||
|
||||
writeConfig := func(name, contents string) {
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, name+".yaml"), []byte(contents), 0o600)).To(Succeed())
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "backend-load-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err := system.GetSystemState(system.WithModelPath(tempDir))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
appConfig = config.NewApplicationConfig(config.WithSystemState(systemState))
|
||||
configLoader = config.NewModelConfigLoader(tempDir)
|
||||
modelLoader = model.NewModelLoader(systemState) // no backends installed
|
||||
|
||||
app = echo.New()
|
||||
app.POST("/backend/load", LoadModelEndpoint(configLoader, modelLoader, appConfig))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
It("rejects a request with no model name", func() {
|
||||
rec := post(`{}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(decode(rec).Message).To(ContainSubstring("model is required"))
|
||||
})
|
||||
|
||||
It("reports a load failure for a regular model with nothing loaded", func() {
|
||||
writeConfig("solo", "name: solo\n")
|
||||
|
||||
rec := post(`{"model":"solo"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusInternalServerError))
|
||||
|
||||
resp := decode(rec)
|
||||
Expect(resp.Loaded).To(BeEmpty())
|
||||
Expect(resp.Message).To(ContainSubstring("failed to load model"))
|
||||
})
|
||||
|
||||
It("expands a pipeline model and reports each sub-model that failed to load", func() {
|
||||
writeConfig("voicebot", "name: voicebot\npipeline:\n vad: vad-m\n transcription: stt-m\n llm: llm-m\n tts: tts-m\n")
|
||||
writeConfig("vad-m", "name: vad-m\n")
|
||||
writeConfig("stt-m", "name: stt-m\n")
|
||||
writeConfig("llm-m", "name: llm-m\n")
|
||||
writeConfig("tts-m", "name: tts-m\n")
|
||||
|
||||
rec := post(`{"model":"voicebot"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusInternalServerError))
|
||||
|
||||
resp := decode(rec)
|
||||
Expect(resp.Message).To(ContainSubstring("failed to load model"))
|
||||
// The pipeline stub itself is never loaded; its sub-models are what the
|
||||
// endpoint tries, so the error names them rather than "voicebot".
|
||||
Expect(resp.Message).To(ContainSubstring("vad-m"))
|
||||
Expect(resp.Message).ToNot(ContainSubstring("voicebot"))
|
||||
})
|
||||
})
|
||||
@@ -51,9 +51,6 @@ func (stubClient) EditModelConfig(_ context.Context, _ string, _ map[string]any)
|
||||
return nil
|
||||
}
|
||||
func (stubClient) ReloadModels(_ context.Context) error { return nil }
|
||||
func (stubClient) LoadModel(_ context.Context, model string) ([]string, error) {
|
||||
return []string{model}, nil
|
||||
}
|
||||
func (stubClient) SetAlias(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,23 +49,62 @@ func modelCapabilities(cfg *config.ModelConfig) []string {
|
||||
return caps
|
||||
}
|
||||
|
||||
// hasVisionSupport reports whether the model can accept image inputs.
|
||||
// The detection heuristic is the canonical config.ModelConfig.VisionSupported —
|
||||
// kept as a thin wrapper here so the Ollama capability mapping reads cleanly.
|
||||
// hasVisionSupport reports whether the model can accept image inputs. We avoid
|
||||
// cfg.HasUsecases(FLAG_VISION) because GuessUsecases has no FLAG_VISION case
|
||||
// and returns true for any chat model — see core/config/model_config.go. Instead
|
||||
// we look for explicit signals: KnownUsecases bit, multimodal projector, or
|
||||
// template/backend-reported multimodal markers.
|
||||
func hasVisionSupport(cfg *config.ModelConfig) bool {
|
||||
return cfg.VisionSupported()
|
||||
if cfg.KnownUsecases != nil && (*cfg.KnownUsecases&config.FLAG_VISION) == config.FLAG_VISION {
|
||||
return true
|
||||
}
|
||||
if cfg.MMProj != "" {
|
||||
return true
|
||||
}
|
||||
if cfg.TemplateConfig.Multimodal != "" {
|
||||
return true
|
||||
}
|
||||
if cfg.MediaMarker != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasToolSupport reports whether the model is wired up for tool / function
|
||||
// calling. Delegates to the canonical config.ModelConfig.ToolSupported.
|
||||
// hasToolSupport reports whether the model is wired up for tool / function calling.
|
||||
// We look for any of the explicit configuration knobs LocalAI uses to drive
|
||||
// function-call extraction (regex match, response regex, grammar triggers, XML
|
||||
// format) or for the auto-detected tool-format markers populated by the
|
||||
// llama.cpp backend during model load.
|
||||
func hasToolSupport(cfg *config.ModelConfig) bool {
|
||||
return cfg.ToolSupported()
|
||||
fc := cfg.FunctionsConfig
|
||||
if fc.ToolFormatMarkers != nil && fc.ToolFormatMarkers.FormatType != "" {
|
||||
return true
|
||||
}
|
||||
if len(fc.JSONRegexMatch) > 0 || len(fc.ResponseRegex) > 0 {
|
||||
return true
|
||||
}
|
||||
if fc.XMLFormatPreset != "" || fc.XMLFormat != nil {
|
||||
return true
|
||||
}
|
||||
if len(fc.GrammarConfig.GrammarTriggers) > 0 || fc.GrammarConfig.SchemaType != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasThinkingSupport reports whether the model has reasoning / thinking enabled.
|
||||
// Delegates to the canonical config.ModelConfig.ThinkingSupported.
|
||||
// LocalAI sets DisableReasoning=false (or leaves thinking markers configured)
|
||||
// when the backend probe reports that the model supports thinking.
|
||||
func hasThinkingSupport(cfg *config.ModelConfig) bool {
|
||||
return cfg.ThinkingSupported()
|
||||
rc := cfg.ReasoningConfig
|
||||
if rc.DisableReasoning != nil && !*rc.DisableReasoning {
|
||||
return true
|
||||
}
|
||||
if len(rc.ThinkingStartTokens) > 0 || len(rc.TagPairs) > 0 {
|
||||
// Explicit thinking markers imply support unless explicitly disabled.
|
||||
return rc.DisableReasoning == nil || !*rc.DisableReasoning
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// quantRegex matches GGUF-style quantization suffixes (Q4_K_M, Q8_0, IQ3_XS, F16, ...).
|
||||
|
||||
@@ -21,11 +21,48 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
authDB = db[0]
|
||||
}
|
||||
return func(c echo.Context) error {
|
||||
modelNames, err := listVisibleModelNames(c, bcl, ml, authDB)
|
||||
// If blank, no filter is applied.
|
||||
filter := c.QueryParam("filter")
|
||||
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
var policy galleryop.LooseFilePolicy
|
||||
excludeConfigured := c.QueryParam("excludeConfigured")
|
||||
if excludeConfigured == "" || excludeConfigured == "true" {
|
||||
policy = galleryop.SKIP_IF_CONFIGURED
|
||||
} else {
|
||||
policy = galleryop.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
|
||||
}
|
||||
|
||||
filterFn, err := config.BuildNameFilterFn(filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelNames, err := galleryop.ListModels(bcl, ml, filterFn, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter models by user's allowlist if auth is enabled
|
||||
if authDB != nil {
|
||||
if user := auth.GetUser(c); user != nil && user.Role != auth.RoleAdmin {
|
||||
perm, err := auth.GetCachedUserPermissions(c, authDB, user.ID)
|
||||
if err == nil && perm.AllowedModels.Enabled {
|
||||
allowed := map[string]bool{}
|
||||
for _, m := range perm.AllowedModels.Models {
|
||||
allowed[m] = true
|
||||
}
|
||||
filtered := make([]string, 0, len(modelNames))
|
||||
for _, m := range modelNames {
|
||||
if allowed[m] {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
modelNames = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map from a slice of names to a slice of OpenAIModel response objects
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
for _, m := range modelNames {
|
||||
@@ -38,53 +75,3 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// listVisibleModelNames resolves the model names visible to the caller, applying
|
||||
// the same query filters (filter, excludeConfigured) and per-user allowlist as
|
||||
// the OpenAI models listing. Shared by ListModelsEndpoint and
|
||||
// ListModelCapabilitiesEndpoint so both stay consistent.
|
||||
func listVisibleModelNames(c echo.Context, bcl *config.ModelConfigLoader, ml *model.ModelLoader, authDB *gorm.DB) ([]string, error) {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.QueryParam("filter")
|
||||
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
var policy galleryop.LooseFilePolicy
|
||||
excludeConfigured := c.QueryParam("excludeConfigured")
|
||||
if excludeConfigured == "" || excludeConfigured == "true" {
|
||||
policy = galleryop.SKIP_IF_CONFIGURED
|
||||
} else {
|
||||
policy = galleryop.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
|
||||
}
|
||||
|
||||
filterFn, err := config.BuildNameFilterFn(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelNames, err := galleryop.ListModels(bcl, ml, filterFn, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter models by user's allowlist if auth is enabled
|
||||
if authDB != nil {
|
||||
if user := auth.GetUser(c); user != nil && user.Role != auth.RoleAdmin {
|
||||
perm, err := auth.GetCachedUserPermissions(c, authDB, user.ID)
|
||||
if err == nil && perm.AllowedModels.Enabled {
|
||||
allowed := map[string]bool{}
|
||||
for _, m := range perm.AllowedModels.Models {
|
||||
allowed[m] = true
|
||||
}
|
||||
filtered := make([]string, 0, len(modelNames))
|
||||
for _, m := range modelNames {
|
||||
if allowed[m] {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
modelNames = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modelNames, nil
|
||||
}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ListModelCapabilitiesEndpoint is a LocalAI-specific extension of the OpenAI
|
||||
// models listing. It returns the same set of models as /v1/models but enriches
|
||||
// each entry with the capabilities and input/output modalities the model
|
||||
// supports, so clients can decide whether an image/audio/video attachment can be
|
||||
// handed to a given model directly (or must be converted/transcribed first).
|
||||
//
|
||||
// It is purely additive: clients that don't know about it keep using /v1/models
|
||||
// and see no change.
|
||||
// @Summary List available models enriched with capabilities and input/output modalities.
|
||||
// @Tags models
|
||||
// @Success 200 {object} schema.ModelCapabilitiesResponse "Response"
|
||||
// @Router /v1/models/capabilities [get]
|
||||
func ListModelCapabilitiesEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, db ...*gorm.DB) echo.HandlerFunc {
|
||||
var authDB *gorm.DB
|
||||
if len(db) > 0 {
|
||||
authDB = db[0]
|
||||
}
|
||||
return func(c echo.Context) error {
|
||||
modelNames, err := listVisibleModelNames(c, bcl, ml, authDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataModels := []schema.ModelCapabilities{}
|
||||
for _, m := range modelNames {
|
||||
entry := schema.ModelCapabilities{ID: m, Object: "model"}
|
||||
if cfg, ok := bcl.GetModelConfig(m); ok {
|
||||
entry.Capabilities = cfg.Capabilities()
|
||||
entry.InputModalities = cfg.InputModalities()
|
||||
entry.OutputModalities = cfg.OutputModalities()
|
||||
}
|
||||
dataModels = append(dataModels, entry)
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.ModelCapabilitiesResponse{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ListModelCapabilitiesEndpoint", func() {
|
||||
var (
|
||||
e *echo.Echo
|
||||
tmpDir string
|
||||
bcl *config.ModelConfigLoader
|
||||
ml *model.ModelLoader
|
||||
appConf *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
e = echo.New()
|
||||
tmpDir, err = os.MkdirTemp("", "models-caps-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
st, err := system.GetSystemState(system.WithModelPath(tmpDir))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
ml = model.NewModelLoader(st)
|
||||
bcl = config.NewModelConfigLoader(tmpDir)
|
||||
appConf = config.NewApplicationConfig()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
writeConfig := func(name, yaml string) {
|
||||
path := filepath.Join(tmpDir, name+".yaml")
|
||||
Expect(os.WriteFile(path, []byte(yaml), 0o644)).To(Succeed())
|
||||
Expect(bcl.ReadModelConfig(path)).To(Succeed())
|
||||
}
|
||||
|
||||
// call exercises the endpoint with auth disabled (no auth DB), which is the
|
||||
// standard deployment path. The per-user allowlist branch is shared verbatim
|
||||
// with ListModelsEndpoint (listVisibleModelNames) and covered there.
|
||||
call := func() schema.ModelCapabilitiesResponse {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/models/capabilities", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ListModelCapabilitiesEndpoint(bcl, ml, appConf)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp schema.ModelCapabilitiesResponse
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
return resp
|
||||
}
|
||||
|
||||
entryFor := func(resp schema.ModelCapabilitiesResponse, id string) *schema.ModelCapabilities {
|
||||
for i := range resp.Data {
|
||||
if resp.Data[i].ID == id {
|
||||
return &resp.Data[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
It("returns the list envelope even with no models", func() {
|
||||
resp := call()
|
||||
Expect(resp.Object).To(Equal("list"))
|
||||
})
|
||||
|
||||
It("enriches a vision chat model with capabilities and image input modality", func() {
|
||||
writeConfig("vlm", `
|
||||
name: vlm
|
||||
backend: llama-cpp
|
||||
known_usecases:
|
||||
- FLAG_CHAT
|
||||
- FLAG_VISION
|
||||
template:
|
||||
chat: "{{ .Input }}"
|
||||
parameters:
|
||||
model: qwen2.5-vl-Q4_K_M.gguf
|
||||
`)
|
||||
entry := entryFor(call(), "vlm")
|
||||
Expect(entry).NotTo(BeNil())
|
||||
Expect(entry.Object).To(Equal("model"))
|
||||
Expect(entry.Capabilities).To(ContainElements("chat", "vision"))
|
||||
Expect(entry.InputModalities).To(ContainElements("text", "image"))
|
||||
Expect(entry.OutputModalities).To(ContainElement("text"))
|
||||
})
|
||||
|
||||
It("marks a parakeet model as an audio-in/text-out transcription model", func() {
|
||||
writeConfig("parakeet", `
|
||||
name: parakeet
|
||||
backend: parakeet-cpp
|
||||
known_usecases:
|
||||
- FLAG_TRANSCRIPT
|
||||
parameters:
|
||||
model: parakeet-tdt-0.6b
|
||||
`)
|
||||
entry := entryFor(call(), "parakeet")
|
||||
Expect(entry).NotTo(BeNil())
|
||||
Expect(entry.Capabilities).To(ContainElement("transcript"))
|
||||
Expect(entry.InputModalities).To(Equal([]string{"audio"}))
|
||||
Expect(entry.OutputModalities).To(Equal([]string{"text"}))
|
||||
Expect(entry.Capabilities).NotTo(ContainElement("chat"))
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
@@ -267,12 +266,6 @@ type Model interface {
|
||||
// grpcerrors.IsLiveTranscriptionUnsupported.
|
||||
TranscribeLive(ctx context.Context, language string, onEvent func(backend.LiveTranscriptionEvent)) (backend.LiveTranscriptionSession, error)
|
||||
PredictConfig() *config.ModelConfig
|
||||
// Warmup eagerly loads the pipeline's sub-model backends into memory so the
|
||||
// first realtime turn doesn't pay each backend's cold-start load cost. Loads
|
||||
// run concurrently; Warmup blocks until they all finish and returns a joined
|
||||
// error naming every stage that failed to load (nil if all succeeded), so a
|
||||
// caller can surface model-load failures at session start instead of mid-call.
|
||||
Warmup(ctx context.Context) error
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@@ -590,8 +583,18 @@ func runRealtimeSession(application *application.Application, t Transport, model
|
||||
}
|
||||
session.ModelInterface = m
|
||||
|
||||
// The voice gate is built before the warm-up below so its
|
||||
// speaker-recognition model can warm alongside the pipeline stages.
|
||||
if session.SummaryModel != "" {
|
||||
summaryModelName := session.SummaryModel
|
||||
sid := sessionID
|
||||
session.summarizerFactory = func() (Model, error) {
|
||||
summaryCfg, lerr := application.ModelConfigLoader().LoadModelConfigFileByNameDefaultOptions(summaryModelName, application.ApplicationConfig())
|
||||
if lerr != nil {
|
||||
return nil, fmt.Errorf("load summary model config %q: %w", summaryModelName, lerr)
|
||||
}
|
||||
return newModel(&summaryCfg.Pipeline, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), evaluator, buildRealtimeRoutingContext(application, sid))
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Pipeline.VoiceGateEnabled() {
|
||||
gate, gerr := newVoiceGate(
|
||||
*cfg.Pipeline.VoiceRecognition,
|
||||
@@ -609,47 +612,6 @@ func runRealtimeSession(application *application.Application, t Transport, model
|
||||
xlog.Info("realtime voice recognition gate enabled", "mode", gate.cfg.Mode, "when", gate.cfg.When)
|
||||
}
|
||||
|
||||
// Warm the pipeline's sub-model backends before announcing the session.
|
||||
// Loads run concurrently but we block here until they all finish, so a model
|
||||
// that fails to load (missing weights, bad backend, OOM) surfaces as an error
|
||||
// at session start rather than stalling — or failing — mid-call on the first
|
||||
// turn (VAD on the first audio chunk, STT at end-of-speech, LLM on the first
|
||||
// reply, TTS on the first spoken output). On success the backends are already
|
||||
// resident, so the first turn pays no cold-start cost. Opt out per pipeline
|
||||
// with `pipeline.disable_warmup: true` to restore lazy load-on-first-use
|
||||
// (errors then surface on first use instead of at session start).
|
||||
if !cfg.Pipeline.DisableWarmup {
|
||||
warmErr := make(chan error, 1)
|
||||
go func() { warmErr <- m.Warmup(context.Background()) }()
|
||||
// The voice-gate model warms concurrently with the pipeline stages: an
|
||||
// enforced gate blocks each utterance on speaker resolution, so its
|
||||
// cold-start would otherwise land on the first turn too. (Compaction's
|
||||
// summary_model stays lazy — it only runs off the response path.)
|
||||
var gateErr error
|
||||
if session.voiceGate != nil {
|
||||
_, gateErr = backend.PreloadStages(context.Background(), application.ModelLoader(), application.ApplicationConfig(), []backend.PreloadStage{
|
||||
{Role: "voice_recognition", Cfg: session.voiceGate.recCfg},
|
||||
})
|
||||
}
|
||||
if err := errors.Join(<-warmErr, gateErr); err != nil {
|
||||
xlog.Error("realtime warmup failed", "model", model, "error", err)
|
||||
sendError(t, "model_load_error", "Failed to load pipeline models: "+err.Error(), "", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if session.SummaryModel != "" {
|
||||
summaryModelName := session.SummaryModel
|
||||
sid := sessionID
|
||||
session.summarizerFactory = func() (Model, error) {
|
||||
summaryCfg, lerr := application.ModelConfigLoader().LoadModelConfigFileByNameDefaultOptions(summaryModelName, application.ApplicationConfig())
|
||||
if lerr != nil {
|
||||
return nil, fmt.Errorf("load summary model config %q: %w", summaryModelName, lerr)
|
||||
}
|
||||
return newModel(&summaryCfg.Pipeline, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), evaluator, buildRealtimeRoutingContext(application, sid))
|
||||
}
|
||||
}
|
||||
|
||||
// Store the session and notify the transport (for WebRTC audio track handling)
|
||||
sessionLock.Lock()
|
||||
sessions[sessionID] = session
|
||||
@@ -1163,21 +1125,6 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
return err
|
||||
}
|
||||
session.ModelInterface = m
|
||||
// A session.update that swaps the model/voice rebuilds the pipeline, so
|
||||
// warm the new backends too (unless opted out) — otherwise the next turn
|
||||
// pays the cold-start load the original session warm-up already avoided.
|
||||
// Unlike session start this stays non-blocking: updateSession runs under
|
||||
// the global sessionLock, so blocking on a multi-second load here would
|
||||
// stall every other session. Load errors are logged (and still surface on
|
||||
// first use); per-stage failures are already warned inside
|
||||
// backend.PreloadStages.
|
||||
if !session.ModelConfig.Pipeline.DisableWarmup {
|
||||
go func() {
|
||||
if err := m.Warmup(context.Background()); err != nil {
|
||||
xlog.Error("realtime warmup failed after session.update", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.TurnDetectionSet {
|
||||
|
||||
@@ -174,8 +174,6 @@ func (m *fakeModel) TranscribeLive(_ context.Context, _ string, onEvent func(bac
|
||||
|
||||
func (m *fakeModel) PredictConfig() *config.ModelConfig { return m.cfg }
|
||||
|
||||
func (m *fakeModel) Warmup(ctx context.Context) error { return nil }
|
||||
|
||||
// fakeLiveSession records what semantic_vad fed and closed; closeEvents are
|
||||
// replayed through onEvent during Close, mimicking the backend's finalize
|
||||
// flush (trailing delta + Final) landing before Close returns.
|
||||
|
||||
@@ -110,15 +110,6 @@ func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) Warmup(ctx context.Context) error {
|
||||
_, err := backend.PreloadStages(ctx, m.modelLoader, m.appConfig, []backend.PreloadStage{
|
||||
{Role: "vad", Cfg: m.VADConfig},
|
||||
{Role: "transcription", Cfg: m.TranscriptionConfig},
|
||||
{Role: "sound_detection", Cfg: m.SoundDetectionConfig},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) {
|
||||
return backend.VAD(request, ctx, m.modelLoader, m.appConfig, *m.VADConfig)
|
||||
}
|
||||
@@ -369,17 +360,6 @@ func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
return m.LLMConfig
|
||||
}
|
||||
|
||||
func (m *wrappedModel) Warmup(ctx context.Context) error {
|
||||
_, err := backend.PreloadStages(ctx, m.modelLoader, m.appConfig, []backend.PreloadStage{
|
||||
{Role: "vad", Cfg: m.VADConfig},
|
||||
{Role: "transcription", Cfg: m.TranscriptionConfig},
|
||||
{Role: "llm", Cfg: m.LLMConfig},
|
||||
{Role: "tts", Cfg: m.TTSConfig},
|
||||
{Role: "sound_detection", Cfg: m.SoundDetectionConfig},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// wavStreamHeaderBytes is the size of the WAV header that backend.ModelTTSStream
|
||||
// emits as its first audio callback; the sample rate lives at byte offset 24.
|
||||
const wavStreamHeaderBytes = 44
|
||||
@@ -460,7 +440,7 @@ func loadSoundDetectionConfig(pipeline *config.Pipeline, cl *config.ModelConfigL
|
||||
if pipeline.SoundDetection == "" {
|
||||
return nil, nil
|
||||
}
|
||||
cfg, err := cl.LoadResolvedModelConfig(pipeline.SoundDetection, ml.ModelPath)
|
||||
cfg, err := loadPipelineSubModel(cl, pipeline.SoundDetection, ml.ModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load sound detection config: %w", err)
|
||||
}
|
||||
@@ -471,7 +451,7 @@ func loadSoundDetectionConfig(pipeline *config.Pipeline, cl *config.ModelConfigL
|
||||
}
|
||||
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
|
||||
cfgVAD, err := cl.LoadResolvedModelConfig(pipeline.VAD, ml.ModelPath)
|
||||
cfgVAD, err := loadPipelineSubModel(cl, pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -481,7 +461,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig
|
||||
return nil, nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgSST, err := cl.LoadResolvedModelConfig(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := loadPipelineSubModel(cl, pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -570,11 +550,30 @@ func buildRealtimeRoutingContext(a *application.Application, sessionID string) *
|
||||
}
|
||||
}
|
||||
|
||||
// loadPipelineSubModel loads a pipeline sub-model config by name and follows a
|
||||
// single alias hop, so a pipeline that references an alias (e.g. `llm: default`)
|
||||
// gets the alias target's full config (Backend, Model, ...) rather than the
|
||||
// alias stub with an empty Backend. Without this the alias survives unresolved
|
||||
// into model loading and fails downstream — notably in distributed mode with
|
||||
// "backend name is empty". Mirrors the top-level alias resolution in
|
||||
// core/http/middleware/request.go.
|
||||
func loadPipelineSubModel(cl *config.ModelConfigLoader, name, modelPath string) (*config.ModelConfig, error) {
|
||||
cfg, err := cl.LoadModelConfigFileByName(name, modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolved, _, err := cl.ResolveAlias(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator, routing *RealtimeRoutingContext) (Model, error) {
|
||||
xlog.Debug("Creating new model pipeline model", "pipeline", pipeline)
|
||||
|
||||
cfgVAD, err := cl.LoadResolvedModelConfig(pipeline.VAD, ml.ModelPath)
|
||||
cfgVAD, err := loadPipelineSubModel(cl, pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -585,7 +584,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
}
|
||||
|
||||
// TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process
|
||||
cfgSST, err := cl.LoadResolvedModelConfig(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := loadPipelineSubModel(cl, pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -617,7 +616,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
xlog.Debug("Loading a wrapped model")
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgLLM, err := cl.LoadResolvedModelConfig(pipeline.LLM, ml.ModelPath)
|
||||
cfgLLM, err := loadPipelineSubModel(cl, pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -632,7 +631,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
applyPipelineReasoning(cfgLLM, *pipeline)
|
||||
applyPipelineThinking(cfgLLM, *pipeline)
|
||||
|
||||
cfgTTS, err := cl.LoadResolvedModelConfig(pipeline.TTS, ml.ModelPath)
|
||||
cfgTTS, err := loadPipelineSubModel(cl, pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package config_test
|
||||
package openai
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -10,14 +10,14 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// LoadResolvedModelConfig must resolve a model that references an alias
|
||||
// (e.g. a pipeline with `llm: default`) one hop to the alias target's full
|
||||
// config — so the effective backend is the target's backend, not the empty
|
||||
// backend of the alias stub. This mirrors the top-level alias resolution done
|
||||
// in core/http/middleware/request.go, which the realtime pipeline previously
|
||||
// loadPipelineSubModel must resolve a pipeline sub-model that references an
|
||||
// alias (e.g. `llm: default`) one hop to the alias target's full config — so
|
||||
// the effective backend is the target's backend, not the empty backend of the
|
||||
// alias stub. This mirrors the top-level alias resolution done in
|
||||
// core/http/middleware/request.go, which the realtime pipeline previously
|
||||
// skipped (failing in distributed mode with "backend name is empty").
|
||||
var _ = Describe("LoadResolvedModelConfig", func() {
|
||||
It("resolves an alias one hop to the target's config", func() {
|
||||
var _ = Describe("loadPipelineSubModel", func() {
|
||||
It("resolves a sub-model alias one hop to the target's config", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
|
||||
// A real model config with a concrete backend.
|
||||
@@ -38,13 +38,13 @@ alias: real-llm
|
||||
Expect(cl.LoadModelConfigsFromPath(tmpDir)).To(Succeed())
|
||||
|
||||
// Resolving the alias must follow the hop to the target's full config.
|
||||
resolved, err := cl.LoadResolvedModelConfig("default", tmpDir)
|
||||
resolved, err := loadPipelineSubModel(cl, "default", tmpDir)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(resolved.IsAlias()).To(BeFalse())
|
||||
Expect(resolved.Backend).To(Equal("llama-cpp"))
|
||||
|
||||
// A non-alias name must load unchanged.
|
||||
direct, err := cl.LoadResolvedModelConfig("real-llm", tmpDir)
|
||||
direct, err := loadPipelineSubModel(cl, "real-llm", tmpDir)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(direct.Backend).To(Equal("llama-cpp"))
|
||||
Expect(direct.Name).To(Equal("real-llm"))
|
||||
@@ -21,7 +21,6 @@ type namedEmbedding struct {
|
||||
// drive the realtime pipeline.
|
||||
type voiceGate struct {
|
||||
cfg config.PipelineVoiceRecognition // normalized
|
||||
recCfg *config.ModelConfig // resolved speaker-recognition model, for warm-up
|
||||
registry voicerecognition.Registry // identify mode (nil otherwise)
|
||||
refEmbeds []namedEmbedding // verify mode, pre-embedded refs
|
||||
refAudios []config.VoiceReference // verify + anti-spoofing: ref paths
|
||||
@@ -73,9 +72,7 @@ func newVoiceGate(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Resolved like every other pipeline sub-model (one alias hop), so an
|
||||
// aliased voice_recognition model gets its target's backend.
|
||||
recCfg, err := cl.LoadResolvedModelConfig(cfg.Model, ml.ModelPath)
|
||||
recCfg, err := cl.LoadModelConfigFileByName(cfg.Model, ml.ModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("voice_recognition: failed to load model %q: %w", cfg.Model, err)
|
||||
}
|
||||
@@ -85,7 +82,6 @@ func newVoiceGate(
|
||||
|
||||
g := &voiceGate{
|
||||
cfg: cfg,
|
||||
recCfg: recCfg,
|
||||
registry: registry,
|
||||
embedFn: func(ctx context.Context, wavPath string) ([]float32, error) {
|
||||
res, err := backend.VoiceEmbed(ctx, wavPath, ml, appConfig, *recCfg)
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// Warmup delegates to backend.PreloadStages (its concurrency, nil-skipping and
|
||||
// error-joining semantics are pinned in core/backend). These specs pin the
|
||||
// wiring instead: each realtime model type must warm exactly its configured
|
||||
// stages under the right pipeline-role labels. No backends are installed, so
|
||||
// every attempted stage fails to load — the joined error is the proof of which
|
||||
// stages were attempted and how they were labeled.
|
||||
var _ = Describe("realtime model Warmup wiring", func() {
|
||||
newLoader := func() (*model.ModelLoader, *config.ApplicationConfig) {
|
||||
systemState, err := system.GetSystemState(system.WithModelPath(GinkgoT().TempDir()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
appConfig := config.NewApplicationConfig(config.WithSystemState(systemState))
|
||||
return model.NewModelLoader(systemState), appConfig
|
||||
}
|
||||
|
||||
It("wrappedModel warms every configured stage under its pipeline role", func() {
|
||||
ml, appConfig := newLoader()
|
||||
m := &wrappedModel{
|
||||
VADConfig: &config.ModelConfig{Name: "vad-m"},
|
||||
TranscriptionConfig: &config.ModelConfig{Name: "stt-m"},
|
||||
LLMConfig: &config.ModelConfig{Name: "llm-m"},
|
||||
TTSConfig: &config.ModelConfig{Name: "tts-m"},
|
||||
SoundDetectionConfig: &config.ModelConfig{Name: "ced-m"},
|
||||
modelLoader: ml,
|
||||
appConfig: appConfig,
|
||||
}
|
||||
|
||||
err := m.Warmup(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
for _, stage := range []string{"vad (vad-m)", "transcription (stt-m)", "llm (llm-m)", "tts (tts-m)", "sound_detection (ced-m)"} {
|
||||
Expect(err.Error()).To(ContainSubstring(stage))
|
||||
}
|
||||
})
|
||||
|
||||
It("transcriptOnlyModel warms its stages and skips absent ones", func() {
|
||||
ml, appConfig := newLoader()
|
||||
m := &transcriptOnlyModel{
|
||||
VADConfig: &config.ModelConfig{Name: "vad-m"},
|
||||
TranscriptionConfig: &config.ModelConfig{Name: "stt-m"},
|
||||
// SoundDetectionConfig nil: an absent stage must be skipped, not
|
||||
// fail the warm-up.
|
||||
modelLoader: ml,
|
||||
appConfig: appConfig,
|
||||
}
|
||||
|
||||
err := m.Warmup(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("vad (vad-m)"))
|
||||
Expect(err.Error()).To(ContainSubstring("transcription (stt-m)"))
|
||||
Expect(err.Error()).ToNot(ContainSubstring("sound_detection"))
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -30,8 +29,6 @@ const testModel = "Qwen3-VL-2B-Instruct-Q4_K_M"
|
||||
|
||||
var _ = Describe("Open Responses API", func() {
|
||||
var app *echo.Echo
|
||||
var localApp *application.Application
|
||||
var localModelDir string
|
||||
var c context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
@@ -41,47 +38,28 @@ var _ = Describe("Open Responses API", func() {
|
||||
|
||||
Context("API with ephemeral models", func() {
|
||||
BeforeEach(func(sc SpecContext) {
|
||||
// This suite exercises the /v1/responses HTTP/protocol contract
|
||||
// (Content-Type, SSE framing, response envelope, error shapes),
|
||||
// not real inference — so it runs against the same prebuilt
|
||||
// mock-backend the rest of the http suite uses instead of
|
||||
// downloading a real model. Skip cleanly when it isn't built.
|
||||
if mockBackendPath == "" {
|
||||
Skip("mock-backend binary not built; run 'make build-mock-backend'")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
backendPath := os.Getenv("BACKENDS_PATH")
|
||||
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
// Isolated model dir carrying a single config named after testModel
|
||||
// but served by the mock backend, so the responses endpoint can
|
||||
// resolve and load the model without any real backend build.
|
||||
localModelDir, err = os.MkdirTemp("", "openresponses-models-")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
mockModelYAML := "name: " + testModel + "\n" +
|
||||
"backend: mock-backend\n" +
|
||||
"parameters:\n" +
|
||||
" model: mock-model.bin\n"
|
||||
Expect(os.WriteFile(filepath.Join(localModelDir, testModel+".yaml"), []byte(mockModelYAML), 0644)).To(Succeed())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendDir),
|
||||
system.WithModelPath(localModelDir),
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
localApp, err = application.New(
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithApiKeys([]string{apiKey}),
|
||||
config.WithModelsURL("https://huggingface.co/unsloth/Qwen3-VL-2B-Instruct-GGUF"),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
localApp.ModelLoader().SetExternalBackend("mock-backend", mockBackendPath)
|
||||
|
||||
app, err = API(localApp)
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
@@ -102,24 +80,14 @@ var _ = Describe("Open Responses API", func() {
|
||||
})
|
||||
|
||||
AfterEach(func(sc SpecContext) {
|
||||
// Synchronous app shutdown first — context-cancel cleanup is async
|
||||
// and races test-binary exit, orphaning mock-backend children.
|
||||
if localApp != nil {
|
||||
_ = localApp.Shutdown()
|
||||
localApp = nil
|
||||
}
|
||||
cancel()
|
||||
if app != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app = nil
|
||||
}
|
||||
if localModelDir != "" {
|
||||
_ = os.RemoveAll(localModelDir)
|
||||
localModelDir = ""
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
Context("HTTP Protocol Compliance", func() {
|
||||
@@ -1001,16 +969,13 @@ var _ = Describe("Open Responses API", func() {
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(itemID).ToNot(BeEmpty())
|
||||
|
||||
// Now create a new response with item_reference. Per the OpenAI
|
||||
// Responses spec (and this server's parser in
|
||||
// endpoints/openresponses/responses.go) an item_reference carries
|
||||
// the referenced item in the "id" field, not "item_id".
|
||||
// Now create a new response with item_reference
|
||||
reqBody2 := map[string]any{
|
||||
"model": testModel,
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "item_reference",
|
||||
"id": itemID,
|
||||
"type": "item_reference",
|
||||
"item_id": itemID,
|
||||
},
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
@@ -1040,8 +1005,8 @@ var _ = Describe("Open Responses API", func() {
|
||||
"model": testModel,
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "item_reference",
|
||||
"id": "nonexistent_item_id",
|
||||
"type": "item_reference",
|
||||
"item_id": "nonexistent_item_id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
import { test, expect } from './coverage-fixtures.js'
|
||||
|
||||
// Seeds two-message chat into localStorage so we don't need a live model.
|
||||
async function seedChat(page, history) {
|
||||
await page.addInitScript((h) => {
|
||||
const chat = {
|
||||
id: 'seed1', name: 'Seeded Chat', model: 'test-model',
|
||||
history: h, systemPrompt: '', mcpMode: false, mcpServers: [],
|
||||
clientMCPServers: [], temperature: null, topP: null, topK: null,
|
||||
tokenUsage: { prompt: 0, completion: 0, total: 0 },
|
||||
contextSize: null, createdAt: Date.now(), updatedAt: Date.now(),
|
||||
}
|
||||
localStorage.setItem('localai_chats_data', JSON.stringify({
|
||||
chats: [chat], activeChatId: 'seed1', lastSaved: Date.now(),
|
||||
}))
|
||||
}, history)
|
||||
}
|
||||
|
||||
async function mockModels(page) {
|
||||
await page.route('**/api/models/capabilities', (route) => route.fulfill({
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({ data: [{ id: 'test-model', capabilities: ['FLAG_CHAT'] }] }),
|
||||
}))
|
||||
await page.route('**/api/operations', (route) => route.fulfill({
|
||||
contentType: 'application/json', body: JSON.stringify({ operations: [] }),
|
||||
}))
|
||||
}
|
||||
|
||||
const TWO_TURNS = [
|
||||
{ role: 'user', content: 'first question' },
|
||||
{ role: 'assistant', content: 'first answer' },
|
||||
{ role: 'user', content: 'second question' },
|
||||
{ role: 'assistant', content: 'second answer' },
|
||||
]
|
||||
|
||||
test('duplicate creates an independent copy and switches to it', async ({ page }) => {
|
||||
await mockModels(page)
|
||||
await seedChat(page, TWO_TURNS)
|
||||
await page.goto('/app/chat')
|
||||
|
||||
// Open the chats menu (Ctrl/Cmd+K) and duplicate the seeded chat.
|
||||
// Wait for the menu trigger to mount so its global keydown listener is armed
|
||||
// before we dispatch the shortcut.
|
||||
await page.getByTitle('Conversations (Ctrl/Cmd+K)').waitFor()
|
||||
await page.keyboard.press('Control+k')
|
||||
await page.getByTitle('Duplicate chat').first().click()
|
||||
|
||||
// A new active chat named "Seeded Chat (fork)" with the same 4 messages.
|
||||
await expect(page.locator('.chat-header-title')).toHaveText('Seeded Chat (fork)')
|
||||
await expect(page.locator('.chat-message-user')).toHaveCount(2)
|
||||
await expect(page.locator('.chat-message-assistant')).toHaveCount(2)
|
||||
})
|
||||
|
||||
async function mockCompletion(page, replyText) {
|
||||
await page.route('**/v1/chat/completions', (route) => {
|
||||
const sse =
|
||||
`data: ${JSON.stringify({ choices: [{ delta: { content: replyText } }] })}\n\n` +
|
||||
`data: ${JSON.stringify({ choices: [{ delta: {}, finish_reason: 'stop' }], usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 } })}\n\n` +
|
||||
`data: [DONE]\n\n`
|
||||
route.fulfill({ status: 200, contentType: 'text/event-stream', body: sse })
|
||||
})
|
||||
}
|
||||
|
||||
test('retry regenerates the first answer and drops the later turn', async ({ page }) => {
|
||||
await mockModels(page)
|
||||
// Capture the outbound request body so we can assert the model receives the
|
||||
// truncated history (not the stale downstream turns).
|
||||
let sentMessages = null
|
||||
await page.route('**/v1/chat/completions', (route) => {
|
||||
sentMessages = route.request().postDataJSON()?.messages || []
|
||||
const sse =
|
||||
`data: ${JSON.stringify({ choices: [{ delta: { content: 'REGENERATED first answer' } }] })}\n\n` +
|
||||
`data: ${JSON.stringify({ choices: [{ delta: {}, finish_reason: 'stop' }], usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 } })}\n\n` +
|
||||
`data: [DONE]\n\n`
|
||||
route.fulfill({ status: 200, contentType: 'text/event-stream', body: sse })
|
||||
})
|
||||
await seedChat(page, TWO_TURNS)
|
||||
await page.goto('/app/chat')
|
||||
|
||||
// Hover the FIRST assistant message and click its retry button.
|
||||
const firstAssistant = page.locator('.chat-message-assistant').first()
|
||||
await firstAssistant.hover()
|
||||
await firstAssistant.getByTitle('Regenerate').click()
|
||||
|
||||
// History is truncated to the first user turn, then the new answer streams in;
|
||||
// the second Q/A turn is gone.
|
||||
await expect(page.locator('.chat-message-assistant')).toContainText(['REGENERATED first answer'])
|
||||
await expect(page.locator('.chat-message-user')).toHaveCount(1)
|
||||
await expect(page.locator('.chat-message-assistant')).toHaveCount(1)
|
||||
|
||||
// The OUTBOUND payload must also be truncated: the resent user turn is present,
|
||||
// but the downstream turn and the stale first answer must be gone.
|
||||
const contents = (sentMessages || []).map(m =>
|
||||
typeof m.content === 'string' ? m.content : JSON.stringify(m.content)
|
||||
)
|
||||
expect(contents.join('\n')).toContain('first question')
|
||||
expect(contents.join('\n')).not.toContain('second question')
|
||||
expect(contents.join('\n')).not.toContain('first answer')
|
||||
})
|
||||
|
||||
test('copy chat puts the whole conversation on the clipboard', async ({ page, context }) => {
|
||||
await context.grantPermissions(['clipboard-read', 'clipboard-write'])
|
||||
await mockModels(page)
|
||||
await seedChat(page, TWO_TURNS)
|
||||
await page.goto('/app/chat')
|
||||
|
||||
// Wait for the menu trigger to mount so its global keydown listener is armed
|
||||
// before we dispatch the shortcut (same mount-race guard as the duplicate test).
|
||||
await page.getByTitle('Conversations (Ctrl/Cmd+K)').waitFor()
|
||||
await page.keyboard.press('Control+k')
|
||||
await page.getByTitle('Copy chat').first().click()
|
||||
|
||||
const clip = await page.evaluate(() => navigator.clipboard.readText())
|
||||
expect(clip).toContain('# Seeded Chat')
|
||||
expect(clip).toContain('first answer')
|
||||
expect(clip).toContain('second answer')
|
||||
})
|
||||
|
||||
test('branch from the first answer forks history up to that point', async ({ page }) => {
|
||||
await mockModels(page)
|
||||
await seedChat(page, TWO_TURNS)
|
||||
await page.goto('/app/chat')
|
||||
|
||||
const firstAssistant = page.locator('.chat-message-assistant').first()
|
||||
await firstAssistant.hover()
|
||||
await firstAssistant.getByTitle('Branch from here').click()
|
||||
|
||||
// New active chat "Seeded Chat (fork)" contains only the first Q/A turn.
|
||||
await expect(page.locator('.chat-header-title')).toHaveText('Seeded Chat (fork)')
|
||||
await expect(page.locator('.chat-message-user')).toHaveCount(1)
|
||||
await expect(page.locator('.chat-message-assistant')).toHaveCount(1)
|
||||
await expect(page.locator('.chat-message-assistant')).toContainText(['first answer'])
|
||||
})
|
||||
@@ -72,7 +72,6 @@
|
||||
"actions": {
|
||||
"copy": "Copy",
|
||||
"regenerate": "Regenerate",
|
||||
"branch": "Branch from here",
|
||||
"jumpToLatest": "Jump to latest"
|
||||
},
|
||||
"streaming": {
|
||||
@@ -101,9 +100,7 @@
|
||||
"toasts": {
|
||||
"selectModel": "Please select a model",
|
||||
"copied": "Copied to clipboard",
|
||||
"copyFailed": "Could not copy to clipboard",
|
||||
"chatCopied": "Chat copied to clipboard",
|
||||
"forked": "Created a new chat"
|
||||
"copyFailed": "Could not copy to clipboard"
|
||||
},
|
||||
"menu": {
|
||||
"trigger": "Chats",
|
||||
@@ -113,8 +110,6 @@
|
||||
"noMatch": "No conversations match your search",
|
||||
"noConversations": "No conversations yet",
|
||||
"rename": "Rename",
|
||||
"duplicate": "Duplicate chat",
|
||||
"copyChat": "Copy chat",
|
||||
"exportMarkdown": "Export as Markdown",
|
||||
"deleteChat": "Delete chat",
|
||||
"newChat": "New chat",
|
||||
|
||||
@@ -24,8 +24,6 @@ const ChatsMenu = forwardRef(function ChatsMenu({
|
||||
onDeleteAll,
|
||||
onRename,
|
||||
onExport,
|
||||
onCopyChat,
|
||||
onDuplicate,
|
||||
}, ref) {
|
||||
const { t } = useTranslation('chat')
|
||||
const [open, setOpen] = useState(false)
|
||||
@@ -232,24 +230,6 @@ const ChatsMenu = forwardRef(function ChatsMenu({
|
||||
>
|
||||
<i className="fas fa-pen" />
|
||||
</button>
|
||||
{onDuplicate && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => { e.stopPropagation(); onDuplicate(chat); setOpen(false) }}
|
||||
title={t('menu.duplicate')}
|
||||
>
|
||||
<i className="fas fa-clone" />
|
||||
</button>
|
||||
)}
|
||||
{(chat.history?.length || 0) > 0 && onCopyChat && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => { e.stopPropagation(); onCopyChat(chat) }}
|
||||
title={t('menu.copyChat')}
|
||||
>
|
||||
<i className="fas fa-clipboard" />
|
||||
</button>
|
||||
)}
|
||||
{(chat.history?.length || 0) > 0 && onExport && (
|
||||
<button
|
||||
type="button"
|
||||
|
||||
27
core/http/react-ui/src/hooks/useChat.js
vendored
27
core/http/react-ui/src/hooks/useChat.js
vendored
@@ -141,24 +141,6 @@ export function useChat(initialModel = '') {
|
||||
return chat
|
||||
}, [])
|
||||
|
||||
const forkChat = useCallback((chatId, uptoIndex) => {
|
||||
const src = chats.find(c => c.id === chatId)
|
||||
if (!src) return null
|
||||
const end = typeof uptoIndex === 'number' ? uptoIndex : src.history.length
|
||||
const forked = {
|
||||
...src,
|
||||
id: generateId(),
|
||||
name: `${src.name} (fork)`,
|
||||
history: structuredClone(src.history.slice(0, end)),
|
||||
tokenUsage: { prompt: 0, completion: 0, total: 0 },
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
}
|
||||
setChats(prev => [forked, ...prev])
|
||||
setActiveChatId(forked.id)
|
||||
return forked
|
||||
}, [chats])
|
||||
|
||||
const switchChat = useCallback((chatId) => {
|
||||
setActiveChatId(chatId)
|
||||
setStreamingContent('')
|
||||
@@ -278,12 +260,8 @@ export function useChat(initialModel = '') {
|
||||
if (chat?.systemPrompt) {
|
||||
messages.push({ role: 'system', content: chat.systemPrompt })
|
||||
}
|
||||
// Filter out thinking/reasoning/tool_call/tool_result messages.
|
||||
// options.baseHistory lets callers (e.g. mid-conversation retry) pass the
|
||||
// intended truncated history synchronously; the closure `chat` still holds
|
||||
// the stale pre-truncation state because setChats only schedules an update.
|
||||
const baseHistory = options.baseHistory || chat?.history || []
|
||||
const historyForApi = baseHistory.filter(m =>
|
||||
// Filter out thinking/reasoning/tool_call/tool_result messages
|
||||
const historyForApi = (chat?.history || []).filter(m =>
|
||||
m.role !== 'thinking' && m.role !== 'reasoning' && m.role !== 'tool_call' && m.role !== 'tool_result'
|
||||
)
|
||||
messages.push(...historyForApi, { role: 'user', content: messageContent })
|
||||
@@ -815,7 +793,6 @@ export function useChat(initialModel = '') {
|
||||
tokensPerSecond,
|
||||
maxTokensPerSecond,
|
||||
addChat,
|
||||
forkChat,
|
||||
switchChat,
|
||||
deleteChat,
|
||||
deleteAllChats,
|
||||
|
||||
@@ -33,7 +33,7 @@ function getLastMessagePreview(chat) {
|
||||
return ''
|
||||
}
|
||||
|
||||
function serializeChatAsMarkdown(chat) {
|
||||
function exportChatAsMarkdown(chat) {
|
||||
let md = `# ${chat.name}\n\n`
|
||||
md += `Model: ${chat.model || 'Unknown'}\n`
|
||||
md += `Date: ${new Date(chat.createdAt).toLocaleString()}\n\n---\n\n`
|
||||
@@ -47,11 +47,7 @@ function serializeChatAsMarkdown(chat) {
|
||||
md += `<details><summary>Thinking</summary>\n\n${msg.content}\n\n</details>\n\n`
|
||||
}
|
||||
}
|
||||
return md
|
||||
}
|
||||
|
||||
function downloadChatAsMarkdown(chat) {
|
||||
const blob = new Blob([serializeChatAsMarkdown(chat)], { type: 'text/markdown' })
|
||||
const blob = new Blob([md], { type: 'text/markdown' })
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
@@ -298,7 +294,7 @@ export default function Chat() {
|
||||
const {
|
||||
chats, activeChat, activeChatId, isStreaming, streamingChatId, streamingContent,
|
||||
streamingReasoning, streamingToolCalls, tokensPerSecond, maxTokensPerSecond,
|
||||
addChat, forkChat, switchChat, deleteChat, deleteAllChats, renameChat, updateChatSettings,
|
||||
addChat, switchChat, deleteChat, deleteAllChats, renameChat, updateChatSettings,
|
||||
sendMessage, stopGeneration, clearHistory, getContextUsagePercent, addMessage,
|
||||
} = useChat(urlModel || '')
|
||||
|
||||
@@ -799,27 +795,34 @@ export default function Chat() {
|
||||
await sendMessage(msg, files, mcpOptions)
|
||||
}, [input, files, activeChat, sendMessage, addToast, getToolsForLLM, isClientTool, executeTool, hasAppUI, getAppResource, getToolDefinition])
|
||||
|
||||
const handleRegenerate = useCallback(async (targetIndex) => {
|
||||
const handleRegenerate = useCallback(async () => {
|
||||
if (!activeChat || isStreaming) return
|
||||
const history = activeChat.history
|
||||
const end = typeof targetIndex === 'number' ? targetIndex : history.length
|
||||
// Nearest user message at or before the target answer.
|
||||
let userIdx = -1
|
||||
for (let i = Math.min(end, history.length) - 1; i >= 0; i--) {
|
||||
if (history[i].role === 'user') { userIdx = i; break }
|
||||
let lastUserMsg = null
|
||||
let lastUserFiles = null
|
||||
for (let i = history.length - 1; i >= 0; i--) {
|
||||
if (history[i].role === 'user') {
|
||||
lastUserMsg = typeof history[i].content === 'string' ? history[i].content : history[i].content?.[0]?.text || ''
|
||||
lastUserFiles = history[i].files || []
|
||||
break
|
||||
}
|
||||
}
|
||||
if (userIdx === -1) return
|
||||
const userMsg = typeof history[userIdx].content === 'string'
|
||||
? history[userIdx].content
|
||||
: history[userIdx].content?.[0]?.text || ''
|
||||
const userFiles = history[userIdx].files || []
|
||||
// Drop the user turn and everything after it; sendMessage re-appends it.
|
||||
// Thread the truncated history through explicitly: updateChatSettings only
|
||||
// schedules a state update, so sendMessage's closure would otherwise read
|
||||
// the stale pre-truncation history for the outbound API payload.
|
||||
const baseHistory = history.slice(0, userIdx)
|
||||
updateChatSettings(activeChat.id, { history: baseHistory })
|
||||
await sendMessage(userMsg, userFiles, { baseHistory })
|
||||
if (!lastUserMsg) return
|
||||
|
||||
// Remove everything after and including the last user message
|
||||
const newHistory = []
|
||||
let foundLastUser = false
|
||||
for (let i = history.length - 1; i >= 0; i--) {
|
||||
if (!foundLastUser && history[i].role === 'user') {
|
||||
foundLastUser = true
|
||||
continue
|
||||
}
|
||||
if (foundLastUser) {
|
||||
newHistory.unshift(history[i])
|
||||
}
|
||||
}
|
||||
updateChatSettings(activeChat.id, { history: newHistory })
|
||||
await sendMessage(lastUserMsg, lastUserFiles)
|
||||
}, [activeChat, isStreaming, sendMessage, updateChatSettings])
|
||||
|
||||
const handleKeyDown = (e) => {
|
||||
@@ -849,11 +852,6 @@ export default function Chat() {
|
||||
}
|
||||
}
|
||||
|
||||
const copyChatAsMarkdown = async (chat) => {
|
||||
const ok = await copyToClipboard(serializeChatAsMarkdown(chat))
|
||||
addToast(ok ? t('toasts.chatCopied') : t('toasts.copyFailed'), ok ? 'success' : 'error', ok ? 2000 : 3000)
|
||||
}
|
||||
|
||||
const contextPercent = getContextUsagePercent()
|
||||
|
||||
// Recent chats for the empty state — exclude the current chat and any
|
||||
@@ -894,9 +892,7 @@ export default function Chat() {
|
||||
onDelete={deleteChat}
|
||||
onDeleteAll={promptDeleteAll}
|
||||
onRename={renameChat}
|
||||
onExport={(chat) => downloadChatAsMarkdown(chat)}
|
||||
onCopyChat={(chat) => copyChatAsMarkdown(chat)}
|
||||
onDuplicate={(chat) => { if (forkChat(chat.id)) addToast(t('toasts.forked'), 'success', 2000) }}
|
||||
onExport={(chat) => exportChatAsMarkdown(chat)}
|
||||
/>
|
||||
{activeChat.localaiAssistant && (
|
||||
<span
|
||||
@@ -1188,19 +1184,11 @@ export default function Chat() {
|
||||
<button onClick={() => copyMessage(msg.content)} title={t('actions.copy')}>
|
||||
<i className="fas fa-copy" />
|
||||
</button>
|
||||
{msg.role === 'assistant' && !isStreaming && (
|
||||
<button onClick={() => handleRegenerate(i)} title={t('actions.regenerate')}>
|
||||
{msg.role === 'assistant' && i === activeChat.history.length - 1 && !isStreaming && (
|
||||
<button onClick={handleRegenerate} title={t('actions.regenerate')}>
|
||||
<i className="fas fa-rotate" />
|
||||
</button>
|
||||
)}
|
||||
{msg.role === 'assistant' && !isStreaming && (
|
||||
<button
|
||||
onClick={() => { forkChat(activeChat.id, i + 1); addToast(t('toasts.forked'), 'success', 2000) }}
|
||||
title={t('actions.branch')}
|
||||
>
|
||||
<i className="fas fa-code-branch" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -146,7 +146,6 @@ export default function Manage() {
|
||||
const [distributedMode, setDistributedMode] = useState(false)
|
||||
const [togglingModels, setTogglingModels] = useState(new Set())
|
||||
const [pinningModels, setPinningModels] = useState(new Set())
|
||||
const [loadingModels, setLoadingModels] = useState(new Set())
|
||||
// Expanded row state — keyed by `${tab}:${id}` so switching tabs doesn't
|
||||
// collide and a single row is open at a time per tab.
|
||||
const [expandedKey, setExpandedKey] = useState(null)
|
||||
@@ -314,26 +313,6 @@ export default function Manage() {
|
||||
})
|
||||
}
|
||||
|
||||
// Pre-load a model (or all of a realtime pipeline's sub-models) into memory.
|
||||
// The /backend/load call blocks until loading finishes, so the menu item shows
|
||||
// a loading state while in flight and reports the outcome on completion.
|
||||
const handleLoadModel = async (modelName) => {
|
||||
setLoadingModels(prev => new Set(prev).add(modelName))
|
||||
try {
|
||||
await backendControlApi.load({ model: modelName })
|
||||
addToast(`Loaded ${modelName}`, 'success')
|
||||
setTimeout(fetchLoadedModels, 500)
|
||||
} catch (err) {
|
||||
addToast(`Failed to load: ${err.message}`, 'error')
|
||||
} finally {
|
||||
setLoadingModels(prev => {
|
||||
const next = new Set(prev)
|
||||
next.delete(modelName)
|
||||
return next
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleDeleteModel = (modelName) => {
|
||||
setConfirmDialog({
|
||||
title: 'Delete Model',
|
||||
@@ -708,11 +687,6 @@ export default function Manage() {
|
||||
label: model.disabled ? 'Enable model' : 'Disable model',
|
||||
onClick: () => handleToggleModel(model.id, model.disabled),
|
||||
disabled: togglingModels.has(model.id) },
|
||||
{ key: 'load', icon: 'fa-bolt',
|
||||
label: loadingModels.has(model.id) ? 'Loading…' : 'Load into memory',
|
||||
onClick: () => handleLoadModel(model.id),
|
||||
hidden: isRunning || !!model.disabled,
|
||||
disabled: loadingModels.has(model.id) },
|
||||
{ key: 'stop', icon: 'fa-stop', label: 'Stop model',
|
||||
onClick: () => handleStopModel(model.id), hidden: !isRunning },
|
||||
{ key: 'pin', icon: 'fa-thumbtack',
|
||||
|
||||
3
core/http/react-ui/src/utils/api.js
vendored
3
core/http/react-ui/src/utils/api.js
vendored
@@ -352,9 +352,6 @@ export const realtimeApi = {
|
||||
// Backend control
|
||||
export const backendControlApi = {
|
||||
shutdown: (body) => postJSON(API_CONFIG.endpoints.backendShutdown, body),
|
||||
// Pre-load a model (or all of a realtime pipeline's sub-models) into memory.
|
||||
// body: { model: "<name>" }. Inverse of shutdown.
|
||||
load: (body) => postJSON(API_CONFIG.endpoints.backendLoad, body),
|
||||
}
|
||||
|
||||
// System info
|
||||
|
||||
1
core/http/react-ui/src/utils/config.js
vendored
1
core/http/react-ui/src/utils/config.js
vendored
@@ -106,7 +106,6 @@ export const API_CONFIG = {
|
||||
video: '/video',
|
||||
backendMonitor: '/backend/monitor',
|
||||
backendShutdown: '/backend/shutdown',
|
||||
backendLoad: '/backend/load',
|
||||
modelsApply: '/models/apply',
|
||||
modelsDelete: (name) => `/models/delete/${name}`,
|
||||
modelsAvailable: '/models/available',
|
||||
|
||||
@@ -207,14 +207,9 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
backendMonitorService := monitoring.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||
router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), adminMiddleware)
|
||||
router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), adminMiddleware)
|
||||
// /backend/load is the inverse of /backend/shutdown: pre-load a model (or all
|
||||
// of a realtime pipeline's sub-models) into memory so clients can drive
|
||||
// warm-up explicitly instead of paying the cold-start cost on first use.
|
||||
router.POST("/backend/load", localai.LoadModelEndpoint(cl, ml, appConfig), adminMiddleware)
|
||||
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
|
||||
router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), adminMiddleware)
|
||||
router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), adminMiddleware)
|
||||
router.POST("/v1/backend/load", localai.LoadModelEndpoint(cl, ml, appConfig), adminMiddleware)
|
||||
|
||||
// Traces and backend logs (monitoring)
|
||||
router.GET("/api/traces", localai.GetAPITracesEndpoint(), adminMiddleware)
|
||||
@@ -250,7 +245,6 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
"metrics": "/metrics",
|
||||
"backend_monitor": "/backend/monitor",
|
||||
"backend_shutdown": "/backend/shutdown",
|
||||
"backend_load": "/backend/load",
|
||||
"system": "/system",
|
||||
"version": "/version",
|
||||
"traces": "/api/traces",
|
||||
@@ -272,27 +266,25 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
"version": internal.PrintableVersion(),
|
||||
// Flat endpoint list for backwards compatibility
|
||||
"endpoints": map[string]any{
|
||||
"models": "/v1/models",
|
||||
"models_capabilities": "/v1/models/capabilities",
|
||||
"chat_completions": "/v1/chat/completions",
|
||||
"completions": "/v1/completions",
|
||||
"embeddings": "/v1/embeddings",
|
||||
"config_metadata": "/api/models/config-metadata",
|
||||
"config_json": "/api/models/config-json/:name",
|
||||
"config_patch": "/api/models/config-json/:name",
|
||||
"autocomplete": "/api/models/config-metadata/autocomplete/:provider",
|
||||
"vram_estimate": "/api/models/vram-estimate",
|
||||
"tts": "/tts",
|
||||
"transcription": "/v1/audio/transcriptions",
|
||||
"image_generation": "/v1/images/generations",
|
||||
"swagger": "/swagger/index.html",
|
||||
"instructions": "/api/instructions",
|
||||
"models": "/v1/models",
|
||||
"chat_completions": "/v1/chat/completions",
|
||||
"completions": "/v1/completions",
|
||||
"embeddings": "/v1/embeddings",
|
||||
"config_metadata": "/api/models/config-metadata",
|
||||
"config_json": "/api/models/config-json/:name",
|
||||
"config_patch": "/api/models/config-json/:name",
|
||||
"autocomplete": "/api/models/config-metadata/autocomplete/:provider",
|
||||
"vram_estimate": "/api/models/vram-estimate",
|
||||
"tts": "/tts",
|
||||
"transcription": "/v1/audio/transcriptions",
|
||||
"image_generation": "/v1/images/generations",
|
||||
"swagger": "/swagger/index.html",
|
||||
"instructions": "/api/instructions",
|
||||
},
|
||||
// Categorized endpoint groups for structured discovery
|
||||
"endpoint_groups": map[string]any{
|
||||
"openai_compatible": map[string]string{
|
||||
"models": "/v1/models",
|
||||
"models_capabilities": "/v1/models/capabilities",
|
||||
"chat_completions": "/v1/chat/completions",
|
||||
"completions": "/v1/completions",
|
||||
"embeddings": "/v1/embeddings",
|
||||
|
||||
@@ -257,10 +257,4 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// List models
|
||||
app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.AuthDB()))
|
||||
app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.AuthDB()))
|
||||
|
||||
// List models enriched with capabilities + input/output modalities
|
||||
// (LocalAI-specific, additive superset of /v1/models).
|
||||
capabilitiesHandler := openai.ListModelCapabilitiesEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.AuthDB())
|
||||
app.GET("/v1/models/capabilities", capabilitiesHandler)
|
||||
app.GET("/models/capabilities", capabilitiesHandler)
|
||||
}
|
||||
|
||||
@@ -1243,9 +1243,6 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
Galleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
// The React UI's "Reinstall backend" action reuses this route, so
|
||||
// the op must force even when the backend is already installed.
|
||||
Force: true,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
|
||||
@@ -11,24 +11,6 @@ type BackendMonitorRequest struct {
|
||||
BasicModelRequest
|
||||
}
|
||||
|
||||
// ModelLoadRequest asks LocalAI to pre-load a model into memory by name, so the
|
||||
// first request that uses it pays no cold-start load cost. For a realtime
|
||||
// pipeline model, every configured sub-model (VAD, transcription, LLM, TTS,
|
||||
// sound_detection, voice_recognition) is loaded instead of the pipeline stub.
|
||||
// It is the inverse of the /backend/shutdown request.
|
||||
type ModelLoadRequest struct {
|
||||
BasicModelRequest
|
||||
}
|
||||
|
||||
// ModelLoadResponse reports the outcome of a /backend/load call.
|
||||
type ModelLoadResponse struct {
|
||||
// Loaded lists the model names actually resident in memory after the call.
|
||||
// For a pipeline model these are its sub-models, not the pipeline name.
|
||||
Loaded []string `json:"loaded"`
|
||||
// Message is a short human-readable status ("model loaded", or an error).
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type TokenMetricsRequest struct {
|
||||
BasicModelRequest
|
||||
}
|
||||
|
||||
@@ -251,27 +251,3 @@ type ModelsDataResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIModel `json:"data"`
|
||||
}
|
||||
|
||||
// ModelCapabilities is a strict superset of OpenAIModel that additionally
|
||||
// describes what a model can do and which modalities it accepts/produces. It is
|
||||
// served by the LocalAI-specific /v1/models/capabilities endpoint so clients can
|
||||
// route attachments (image/audio/video) to a model only when it can handle them.
|
||||
type ModelCapabilities struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
// Capabilities are canonical usecase strings (e.g. chat, vision, transcript,
|
||||
// tts, embeddings, image, video) plus the modifiers "tools" and "thinking".
|
||||
Capabilities []string `json:"capabilities"`
|
||||
// InputModalities is the subset of {text,image,audio,video} the model accepts.
|
||||
InputModalities []string `json:"input_modalities"`
|
||||
// OutputModalities is the subset of {text,image,audio,video} the model produces.
|
||||
OutputModalities []string `json:"output_modalities"`
|
||||
}
|
||||
|
||||
// ModelCapabilitiesResponse is the envelope returned by /v1/models/capabilities.
|
||||
// It mirrors ModelsDataResponse so a client can treat it as an enriched
|
||||
// drop-in for /v1/models.
|
||||
type ModelCapabilitiesResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []ModelCapabilities `json:"data"`
|
||||
}
|
||||
|
||||
@@ -6,39 +6,10 @@ import (
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// advisoryLockWaitBackstop bounds, server-side, how long we will wait to
|
||||
// acquire a blocking advisory lock when the caller's context carries no
|
||||
// deadline (e.g. a startup schema migration using context.Background()). It
|
||||
// only exists so such a caller cannot hang forever behind a holder whose
|
||||
// session never releases the lock; it is far longer than any legitimate
|
||||
// guarded section. A var (not const) so tests can shrink it.
|
||||
var advisoryLockWaitBackstop = 30 * time.Minute
|
||||
|
||||
// advisoryLockTimeoutMargin is added to a context's remaining budget when
|
||||
// deriving the server-side lock_timeout, so the Go context's own (cleaner)
|
||||
// cancellation fires first and the server bound is only ever a backstop.
|
||||
const advisoryLockTimeoutMargin = 30 * time.Second
|
||||
|
||||
// advisoryLockWaitBudget returns the server-side lock_timeout to use for a
|
||||
// blocking acquire: the caller context's remaining time plus a margin (so the
|
||||
// Go context still governs), or the backstop when the context has no deadline.
|
||||
// Never returns zero - "wait forever" must not be possible.
|
||||
func advisoryLockWaitBudget(ctx context.Context) time.Duration {
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
budget := time.Until(dl) + advisoryLockTimeoutMargin
|
||||
if budget < time.Second {
|
||||
budget = time.Second
|
||||
}
|
||||
return budget
|
||||
}
|
||||
return advisoryLockWaitBackstop
|
||||
}
|
||||
|
||||
// localLocks holds one buffered channel (capacity 1) per lock key, used as an
|
||||
// in-process mutex for non-PostgreSQL dialects (SQLite). A SQLite auth DB is
|
||||
// effectively single-process, so serializing guarded sections within this
|
||||
@@ -159,27 +130,6 @@ func WithLockCtx(ctx context.Context, db *gorm.DB, key int64, fn func() error) e
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Override any deployment-wide lock_timeout on this dedicated connection.
|
||||
// Operators commonly set a short global lock_timeout (on the role or
|
||||
// database) to bound ordinary row-lock waits. Applied to the blocking
|
||||
// pg_advisory_lock below, it aborts the wait with SQLSTATE 55P03 and turns
|
||||
// LocalAI's intentional cross-replica "wait your turn, then re-check"
|
||||
// coordination into a hard error for the caller (e.g. a chat request that
|
||||
// just wanted to reuse a model another replica is loading).
|
||||
//
|
||||
// We do NOT disable it outright (lock_timeout = 0 would wait forever, which
|
||||
// is unsafe for the schema-migration callers that pass context.Background()).
|
||||
// Instead we set a bound derived from the caller's context: its remaining
|
||||
// budget plus a margin so the Go context's cancellation wins with a clean
|
||||
// error, or a finite backstop when the context has no deadline.
|
||||
waitBudget := advisoryLockWaitBudget(ctx)
|
||||
if _, err := conn.ExecContext(ctx,
|
||||
fmt.Sprintf("SET lock_timeout = %d", waitBudget.Milliseconds())); err != nil {
|
||||
return fmt.Errorf("advisorylock: setting lock_timeout: %w", err)
|
||||
}
|
||||
// Restore the session default before this pooled connection is reused.
|
||||
defer func() { _, _ = conn.ExecContext(context.Background(), "RESET lock_timeout") }()
|
||||
|
||||
if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", key); err != nil {
|
||||
return fmt.Errorf("advisorylock: acquiring lock %d: %w", key, err)
|
||||
}
|
||||
|
||||
@@ -158,87 +158,6 @@ var _ = Describe("AdvisoryLock", func() {
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("waits out a short server-side lock_timeout instead of failing with 55P03", func() {
|
||||
const lockKey int64 = 703
|
||||
|
||||
// Reproduce the production deployment that triggered this: a short
|
||||
// global lock_timeout set on the database. Without the fix, a waiter
|
||||
// blocked on pg_advisory_lock() is aborted by the server after this
|
||||
// window and surfaces SQLSTATE 55P03 ("canceling statement due to
|
||||
// lock timeout") to the caller instead of waiting for its turn.
|
||||
Expect(db.Exec("ALTER DATABASE testdb SET lock_timeout = '300ms'").Error).ToNot(HaveOccurred())
|
||||
sqlDB, err := db.DB()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Drop pooled connections so subsequent ones reconnect and inherit
|
||||
// the new database-level lock_timeout default.
|
||||
sqlDB.SetMaxIdleConns(0)
|
||||
|
||||
holding := make(chan struct{})
|
||||
released := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
herr := WithLockCtx(context.Background(), db, lockKey, func() error {
|
||||
close(holding)
|
||||
// Hold well past the 300ms server lock_timeout.
|
||||
time.Sleep(1 * time.Second)
|
||||
return nil
|
||||
})
|
||||
Expect(herr).ToNot(HaveOccurred())
|
||||
close(released)
|
||||
}()
|
||||
|
||||
<-holding // ensure the holder owns the lock before we contend
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
executed := false
|
||||
start := time.Now()
|
||||
werr := WithLockCtx(ctx, db, lockKey, func() error {
|
||||
executed = true
|
||||
return nil
|
||||
})
|
||||
Expect(werr).ToNot(HaveOccurred(),
|
||||
"waiter should wait out the in-progress hold, not fail with lock_timeout (55P03)")
|
||||
Expect(executed).To(BeTrue())
|
||||
Expect(time.Since(start)).To(BeNumerically(">=", 400*time.Millisecond),
|
||||
"waiter should have actually waited for the holder to release")
|
||||
<-released
|
||||
})
|
||||
|
||||
It("bounds a deadline-less waiter with the backstop instead of waiting forever", func() {
|
||||
const lockKey int64 = 704
|
||||
|
||||
// A caller with no context deadline (e.g. startup schema migration
|
||||
// passing context.Background()) must not hang forever if the holder
|
||||
// never releases. Shrink the backstop so the test is fast.
|
||||
origBackstop := advisoryLockWaitBackstop
|
||||
advisoryLockWaitBackstop = 500 * time.Millisecond
|
||||
DeferCleanup(func() { advisoryLockWaitBackstop = origBackstop })
|
||||
|
||||
holding := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_ = WithLockCtx(context.Background(), db, lockKey, func() error {
|
||||
close(holding)
|
||||
<-release // hold until the test releases us
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
defer close(release)
|
||||
|
||||
<-holding
|
||||
|
||||
start := time.Now()
|
||||
err := WithLockCtx(context.Background(), db, lockKey, func() error {
|
||||
Fail("waiter should not have acquired the still-held lock")
|
||||
return nil
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "deadline-less waiter should give up at the backstop, not hang")
|
||||
Expect(time.Since(start)).To(BeNumerically("<", 5*time.Second),
|
||||
"backstop must cap the wait well under the test timeout")
|
||||
})
|
||||
|
||||
It("serializes concurrent WithLockCtx on same key", func() {
|
||||
const lockKey int64 = 702
|
||||
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
package galleryop_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// The install op must be idempotent unless Force is set: API clients call
|
||||
// POST /backends/apply on every boot to make sure the backend exists, and an
|
||||
// unconditional force here re-downloads the whole backend artifact each time.
|
||||
// Reinstall is an explicit, opted-in action.
|
||||
var _ = Describe("LocalBackendManager force semantics", func() {
|
||||
var (
|
||||
backendsDir string
|
||||
srcDir string
|
||||
mgr *galleryop.LocalBackendManager
|
||||
systemState *system.SystemState
|
||||
ml *model.ModelLoader
|
||||
)
|
||||
|
||||
const installedRunSh = "#!/bin/sh\necho installed\n"
|
||||
const galleryRunSh = "#!/bin/sh\necho from-gallery\n"
|
||||
|
||||
installedRunShPath := func() string {
|
||||
return filepath.Join(backendsDir, "test-backend", "run.sh")
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
backendsDir, err = os.MkdirTemp("", "force-backends-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
srcDir, err = os.MkdirTemp("", "force-src-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// The gallery serves test-backend from a plain directory (offline).
|
||||
// The gallery yaml itself must live under the backends path: file://
|
||||
// galleries outside the trusted root are rejected by the downloader.
|
||||
Expect(os.WriteFile(filepath.Join(srcDir, "run.sh"), []byte(galleryRunSh), 0o755)).To(Succeed())
|
||||
entries := []map[string]any{{"name": "test-backend", "uri": srcDir}}
|
||||
data, err := yaml.Marshal(entries)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
galleryYAML := filepath.Join(backendsDir, "gallery.yaml")
|
||||
Expect(os.WriteFile(galleryYAML, data, 0o644)).To(Succeed())
|
||||
|
||||
// test-backend is already installed, with content that differs from
|
||||
// the gallery's so a reinstall is observable.
|
||||
Expect(os.MkdirAll(filepath.Join(backendsDir, "test-backend"), 0o755)).To(Succeed())
|
||||
Expect(os.WriteFile(installedRunShPath(), []byte(installedRunSh), 0o755)).To(Succeed())
|
||||
|
||||
systemState, err = system.GetSystemState(system.WithBackendPath(backendsDir))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
appConfig := &config.ApplicationConfig{
|
||||
SystemState: systemState,
|
||||
BackendGalleries: []config.Gallery{{Name: "test", URL: "file://" + galleryYAML}},
|
||||
}
|
||||
ml = model.NewModelLoader(systemState)
|
||||
mgr = galleryop.NewLocalBackendManager(appConfig, ml)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(os.RemoveAll(backendsDir)).To(Succeed())
|
||||
Expect(os.RemoveAll(srcDir)).To(Succeed())
|
||||
})
|
||||
|
||||
It("skips an already-installed backend when Force is not set", func() {
|
||||
op := &galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
ID: "op-1",
|
||||
GalleryElementName: "test-backend",
|
||||
}
|
||||
Expect(mgr.InstallBackend(context.Background(), op, nil)).To(Succeed())
|
||||
|
||||
content, err := os.ReadFile(installedRunShPath())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(content)).To(Equal(installedRunSh), "install without Force must not overwrite an installed backend")
|
||||
})
|
||||
|
||||
It("reinstalls an already-installed backend when Force is set", func() {
|
||||
op := &galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||
ID: "op-2",
|
||||
GalleryElementName: "test-backend",
|
||||
Force: true,
|
||||
}
|
||||
Expect(mgr.InstallBackend(context.Background(), op, nil)).To(Succeed())
|
||||
|
||||
content, err := os.ReadFile(installedRunShPath())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(content)).To(Equal(galleryRunSh), "install with Force must overwrite the installed backend")
|
||||
})
|
||||
|
||||
// The LOCALAI_EXTERNAL_BACKENDS boot loop goes through
|
||||
// InstallExternalBackend's gallery-name path on EVERY startup; it must not
|
||||
// force, or each boot re-downloads every listed backend.
|
||||
It("skips an already-installed backend on the non-forced external gallery-name path", func() {
|
||||
err := galleryop.InstallExternalBackend(context.Background(),
|
||||
[]config.Gallery{{Name: "test", URL: "file://" + filepath.Join(backendsDir, "gallery.yaml")}},
|
||||
systemState, ml, nil, "test-backend", "", "", false, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
content, err := os.ReadFile(installedRunShPath())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(content)).To(Equal(installedRunSh), "non-forced external install must not overwrite an installed backend")
|
||||
})
|
||||
})
|
||||
@@ -144,12 +144,7 @@ func (g *GalleryService) backendHandler(op *ManagementOp[gallery.GalleryBackend,
|
||||
// InstallExternalBackend installs a backend from an external source (OCI image, URL, or path).
|
||||
// This method contains the logic to detect the input type and call the appropriate installation function.
|
||||
// It can be used by both CLI and Web UI for installing backends from external sources.
|
||||
//
|
||||
// force applies only to the gallery-name fallback: a URI install (dir/OCI/file)
|
||||
// always writes, but a bare gallery name is an "ensure installed" — the
|
||||
// LOCALAI_EXTERNAL_BACKENDS boot loop runs it on every start and must not
|
||||
// re-download an installed, runnable backend.
|
||||
func InstallExternalBackend(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string, force, requireIntegrity bool) error {
|
||||
func InstallExternalBackend(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string, requireIntegrity bool) error {
|
||||
uri := downloader.URI(backend)
|
||||
switch {
|
||||
case uri.LooksLikeDir():
|
||||
@@ -207,7 +202,7 @@ func InstallExternalBackend(ctx context.Context, galleries []config.Gallery, sys
|
||||
if name != "" || alias != "" {
|
||||
return fmt.Errorf("specifying a name or alias is not supported for gallery backends")
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(ctx, galleries, systemState, modelLoader, backend, downloadStatus, force, requireIntegrity)
|
||||
err := gallery.InstallBackendFromGallery(ctx, galleries, systemState, modelLoader, backend, downloadStatus, true, requireIntegrity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
|
||||
@@ -70,7 +70,6 @@ var _ = Describe("InstallExternalBackend", func() {
|
||||
"test-backend", // gallery name
|
||||
"custom-name", // name should not be allowed
|
||||
"",
|
||||
false, // force
|
||||
false,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -87,7 +86,6 @@ var _ = Describe("InstallExternalBackend", func() {
|
||||
"non-existent-backend",
|
||||
"",
|
||||
"",
|
||||
false, // force
|
||||
false,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -105,7 +103,6 @@ var _ = Describe("InstallExternalBackend", func() {
|
||||
"oci://quay.io/mudler/tests:localai-backend-test",
|
||||
"", // name is required for OCI images
|
||||
"",
|
||||
false, // force
|
||||
false,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -139,7 +136,6 @@ var _ = Describe("InstallExternalBackend", func() {
|
||||
testBackendPath,
|
||||
"", // name should be inferred as "source-backend"
|
||||
"",
|
||||
false, // force
|
||||
false,
|
||||
)
|
||||
// The function should at least attempt to install with the inferred name
|
||||
@@ -159,7 +155,6 @@ var _ = Describe("InstallExternalBackend", func() {
|
||||
testBackendPath,
|
||||
"custom-backend-name",
|
||||
"",
|
||||
false, // force
|
||||
false,
|
||||
)
|
||||
// The function should use the provided name
|
||||
@@ -178,7 +173,6 @@ var _ = Describe("InstallExternalBackend", func() {
|
||||
testBackendPath,
|
||||
"custom-backend-name",
|
||||
"custom-alias",
|
||||
false, // force
|
||||
false,
|
||||
)
|
||||
// The function should accept alias for directory paths
|
||||
|
||||
@@ -110,13 +110,10 @@ func (b *LocalBackendManager) CheckUpgrades(ctx context.Context) (map[string]gal
|
||||
func (b *LocalBackendManager) InstallBackend(ctx context.Context, op *ManagementOp[gallery.GalleryBackend, any], progressCb ProgressCallback) error {
|
||||
if op.ExternalURI != "" {
|
||||
return InstallExternalBackend(ctx, b.backendGalleries, b.systemState, b.modelLoader,
|
||||
progressCb, op.ExternalURI, op.ExternalName, op.ExternalAlias, op.Force, b.requireBackendIntegrity)
|
||||
progressCb, op.ExternalURI, op.ExternalName, op.ExternalAlias, b.requireBackendIntegrity)
|
||||
}
|
||||
// op.Force distinguishes an explicit reinstall from an idempotent
|
||||
// "make sure it's installed" op; the latter must not re-download an
|
||||
// already-runnable backend (supervisors apply on every boot).
|
||||
return gallery.InstallBackendFromGallery(ctx, b.backendGalleries, b.systemState,
|
||||
b.modelLoader, op.GalleryElementName, progressCb, op.Force, b.requireBackendIntegrity)
|
||||
b.modelLoader, op.GalleryElementName, progressCb, true, b.requireBackendIntegrity)
|
||||
}
|
||||
|
||||
func (b *LocalBackendManager) IsDistributed() bool { return false }
|
||||
|
||||
@@ -45,13 +45,6 @@ type ManagementOp[T any, E any] struct {
|
||||
|
||||
// Upgrade is true if this is an upgrade operation (not a fresh install)
|
||||
Upgrade bool
|
||||
|
||||
// Force reinstalls a backend even when it is already installed and
|
||||
// runnable. Without it a backend install op is idempotent — API clients
|
||||
// that ensure a backend exists on every boot must not trigger a full
|
||||
// artifact re-download each time. The UI's explicit "Reinstall backend"
|
||||
// action sets it.
|
||||
Force bool
|
||||
}
|
||||
|
||||
type OpStatus struct {
|
||||
|
||||
@@ -68,13 +68,6 @@ type SmartRouterOptions struct {
|
||||
// the absolute model paths untouched so the worker loads them directly from
|
||||
// the shared volume (#10556). See config.DistributedConfig.SharedModels.
|
||||
SharedModels bool
|
||||
// ModelLoadCeiling is the hard upper bound on how long a single cold-load
|
||||
// attempt (node selection -> backend install -> file staging -> LoadModel)
|
||||
// may run while holding the per-model advisory lock. It backstops every
|
||||
// sub-step's own timeout so a wedged worker can never pin the lock - and
|
||||
// every other replica's request for that model - indefinitely. Zero selects
|
||||
// defaultModelLoadCeiling.
|
||||
ModelLoadCeiling time.Duration
|
||||
}
|
||||
|
||||
// SmartRouter routes inference requests to the best available backend node.
|
||||
@@ -108,18 +101,8 @@ type SmartRouter struct {
|
||||
// sharedModels skips file staging when all nodes mount the same models
|
||||
// directory at the same path (see SmartRouterOptions.SharedModels).
|
||||
sharedModels bool
|
||||
// modelLoadCeiling bounds how long a cold load may hold the per-model
|
||||
// advisory lock (see SmartRouterOptions.ModelLoadCeiling).
|
||||
modelLoadCeiling time.Duration
|
||||
}
|
||||
|
||||
// defaultModelLoadCeiling is the fallback hold ceiling for a cold model load.
|
||||
// It must comfortably exceed the slowest legitimate load - a multi-GB backend
|
||||
// install (DefaultBackendInstallTimeout, 15m) plus staging and the remote
|
||||
// LoadModel (5m) - so it never cuts a real load short; it only ever fires when
|
||||
// a step is genuinely wedged (e.g. a worker that died mid-install).
|
||||
const defaultModelLoadCeiling = 25 * time.Minute
|
||||
|
||||
// probeCacheTTL is how long a successful gRPC HealthCheck on a backend is
|
||||
// trusted before the next request re-probes. Matches healthCheckTTL in
|
||||
// pkg/model/model.go so the single-process and distributed paths share a
|
||||
@@ -134,10 +117,6 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
if factory == nil {
|
||||
factory = &tokenClientFactory{token: opts.AuthToken}
|
||||
}
|
||||
ceiling := opts.ModelLoadCeiling
|
||||
if ceiling <= 0 {
|
||||
ceiling = defaultModelLoadCeiling
|
||||
}
|
||||
return &SmartRouter{
|
||||
registry: registry,
|
||||
unloader: opts.Unloader,
|
||||
@@ -152,7 +131,6 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
prefixConfig: opts.PrefixConfig,
|
||||
pressure: opts.Pressure,
|
||||
sharedModels: opts.SharedModels,
|
||||
modelLoadCeiling: ceiling,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,19 +383,11 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
// the request context. If staging were bound to it, the multi-GB upload
|
||||
// aborts with "context canceled" mid-transfer and large models can never
|
||||
// finish staging (the model-load outage). WithoutCancel keeps the request's
|
||||
// values (prefix chain, etc.) but drops its cancellation/deadline.
|
||||
//
|
||||
// Detaching from the caller is necessary, but it must not be unbounded: the
|
||||
// load runs while holding the per-model advisory lock, and a worker that
|
||||
// dies mid-install (its backend.install never replies) would otherwise pin
|
||||
// that lock (and every other replica's request for the same model) until
|
||||
// the NATS install deadline alone expires. Re-impose a single hard ceiling
|
||||
// over the whole sequence so the lock is always released in bounded time,
|
||||
// even if a sub-step wedges. Each long step still has its own (tighter)
|
||||
// bound; this only backstops them. The per-model advisory lock below
|
||||
// de-dupes concurrent loaders across replicas.
|
||||
loadCtx, cancelLoad := context.WithTimeout(context.WithoutCancel(ctx), r.modelLoadCeiling)
|
||||
defer cancelLoad()
|
||||
// values (prefix chain, etc.) but drops its cancellation/deadline. Each
|
||||
// long step still has its own bound (the file stager's resume budget,
|
||||
// LoadModel's 5m timeout), and the per-model advisory lock below de-dupes
|
||||
// concurrent loaders across replicas.
|
||||
loadCtx := context.WithoutCancel(ctx)
|
||||
loadModel := func(ctx context.Context) (*RouteResult, error) {
|
||||
// Re-check after acquiring lock — another request may have loaded it
|
||||
node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref)
|
||||
@@ -946,14 +916,7 @@ func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNod
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s|%s|%s|%d", node.ID, backendType, modelID, replicaIndex)
|
||||
// DoChan rather than Do so this wait honors ctx cancellation. InstallBackend
|
||||
// blocks for its full NATS deadline (15m by default) when a worker accepts
|
||||
// the request but never replies (e.g. it died mid-install). Without ctx
|
||||
// awareness the caller (holding the per-model advisory lock) would sit there
|
||||
// the whole time; here a cancelled ctx (typically the model-load ceiling)
|
||||
// frees the caller promptly. The shared install keeps running in the
|
||||
// background and still coalesces other callers via singleflight.
|
||||
resCh := r.installFlight.DoChan(key, func() (any, error) {
|
||||
v, err, _ := r.installFlight.Do(key, func() (any, error) {
|
||||
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex, "", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -968,15 +931,10 @@ func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNod
|
||||
}
|
||||
return addr, nil
|
||||
})
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case res := <-resCh:
|
||||
if res.Err != nil {
|
||||
return "", res.Err
|
||||
}
|
||||
return res.Val.(string), nil
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return v.(string), nil
|
||||
}
|
||||
|
||||
func (r *SmartRouter) buildClientForAddr(node *BackendNode, addr string, parallel bool) grpc.Backend {
|
||||
|
||||
@@ -493,44 +493,6 @@ var _ = Describe("SmartRouter", func() {
|
||||
Expect(result.Node.ID).To(Equal("n3"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("worker wedges mid-install (dead node holding the lock)", func() {
|
||||
It("aborts the load at the ModelLoadCeiling instead of blocking forever", func() {
|
||||
// Simulate the production incident: the chosen worker accepts the
|
||||
// backend.install but never replies (it died), so InstallBackend
|
||||
// would otherwise block for its full NATS deadline (15m by
|
||||
// default) while pinning the per-model advisory lock. Route must
|
||||
// give up at the ceiling so the lock is released promptly.
|
||||
reg.findAndLockErr = errors.New("not found")
|
||||
reg.findIdleNode = &BackendNode{ID: "n4", Name: "dead-node", Address: "10.0.0.4:50051"}
|
||||
|
||||
block := make(chan struct{})
|
||||
defer close(block) // let the background install goroutine drain at test end
|
||||
unloader.installHook = func() { <-block }
|
||||
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
ModelLoadCeiling: 200 * time.Millisecond,
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
start := time.Now()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := router.Route(context.Background(), "wedged-model",
|
||||
"models/wedged.gguf", "llama-cpp",
|
||||
&pb.ModelOptions{Model: "models/wedged.gguf"}, false)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
var routeErr error
|
||||
Eventually(done, 5*time.Second).Should(Receive(&routeErr),
|
||||
"Route must not block on a wedged install past the ceiling")
|
||||
Expect(routeErr).To(HaveOccurred())
|
||||
Expect(time.Since(start)).To(BeNumerically("<", 5*time.Second))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("scheduleNewModel (mock-based, via Route)", func() {
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// Prometheus counter for PII events. The EventStore ring buffer is
|
||||
// capacity-bound and meant for recent-audit browsing; operators also want
|
||||
// a monotonic, scrape-friendly signal ("how many detections/blocks per
|
||||
// hour, did the filter stop firing after a deploy"). Record() is the
|
||||
// single choke point every producer already goes through (request
|
||||
// middleware, response scrubbing, MITM proxy connects/intercepts), so one
|
||||
// counter here covers all paths without touching the producers.
|
||||
//
|
||||
// Initialised lazily on first Record so the package works no matter when
|
||||
// (or whether) the Prometheus-backed global MeterProvider is installed —
|
||||
// same pattern as core/services/routing/billing.
|
||||
var (
|
||||
metricsOnce sync.Once
|
||||
eventsCounter metric.Int64Counter
|
||||
)
|
||||
|
||||
func recordEventMetric(e PIIEvent) {
|
||||
metricsOnce.Do(func() {
|
||||
meter := otel.Meter("github.com/mudler/LocalAI")
|
||||
c, err := meter.Int64Counter(
|
||||
"localai_pii_events_total",
|
||||
metric.WithDescription("PII/audit events recorded, labeled by kind, origin, action and direction"),
|
||||
)
|
||||
if err == nil {
|
||||
eventsCounter = c
|
||||
}
|
||||
})
|
||||
if eventsCounter == nil {
|
||||
return
|
||||
}
|
||||
eventsCounter.Add(context.Background(), 1, metric.WithAttributes(
|
||||
attribute.String("kind", string(e.Kind)),
|
||||
attribute.String("origin", string(e.Origin)),
|
||||
attribute.String("action", string(e.Action)),
|
||||
attribute.String("direction", string(e.Direction)),
|
||||
))
|
||||
}
|
||||
@@ -58,7 +58,6 @@ type memoryEventStore struct {
|
||||
}
|
||||
|
||||
func (s *memoryEventStore) Record(_ context.Context, e PIIEvent) error {
|
||||
recordEventMetric(e)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ring[s.cursor] = e
|
||||
|
||||
@@ -134,7 +134,7 @@ func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest,
|
||||
if req.URI != "" {
|
||||
xlog.Info("Installing backend from external URI", "backend", req.Backend, "uri", req.URI, "force", force)
|
||||
if err := galleryop.InstallExternalBackend(
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, force, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("installing backend from gallery: %w", err)
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func (s *backendSupervisor) upgradeBackend(req messaging.BackendUpgradeRequest)
|
||||
if req.URI != "" {
|
||||
xlog.Info("Upgrading backend from external URI", "backend", req.Backend, "uri", req.URI)
|
||||
if err := galleryop.InstallExternalBackend(
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, true, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return fmt.Errorf("upgrading backend from external URI: %w", err)
|
||||
}
|
||||
|
||||
@@ -14,16 +14,6 @@ import (
|
||||
// MaxSnippetSeconds is the maximum number of seconds of audio captured per trace.
|
||||
const MaxSnippetSeconds = 30
|
||||
|
||||
// silenceFloorDBFS is the dBFS value reported for digital silence (RMS or peak
|
||||
// of zero). The true level is -∞ dBFS; reporting a finite floor keeps the
|
||||
// metric present and meaningful in the Traces UI (a scrubbed nil would read as
|
||||
// "missing" rather than "silent"). -120 dBFS sits well below 16-bit PCM's
|
||||
// ~-90 dBFS least-significant-bit floor, so it reads unambiguously as
|
||||
// "effectively silent". JSON-marshal safety for any non-finite float that does
|
||||
// reach a trace is owned centrally by RecordBackendTrace's sanitizer — this
|
||||
// floor is about presentation, not transport.
|
||||
const silenceFloorDBFS = -120.0
|
||||
|
||||
// AudioSnippet captures the first MaxSnippetSeconds of a WAV file and computes
|
||||
// quality metrics. The result is a map suitable for merging into a BackendTrace
|
||||
// Data field. maxBytes caps the embedded base64 waveform so a single TTS or
|
||||
@@ -73,7 +63,7 @@ func AudioSnippetFromPCM(pcm []byte, sampleRate, totalPCMBytes, maxBytes int) ma
|
||||
snippetDuration := float64(len(samples)) / float64(sampleRate)
|
||||
|
||||
rms := sound.CalculateRMS16(samples)
|
||||
rmsDBFS := silenceFloorDBFS
|
||||
rmsDBFS := -math.Inf(1)
|
||||
if rms > 0 {
|
||||
rmsDBFS = 20 * math.Log10(rms/32768.0)
|
||||
}
|
||||
@@ -88,7 +78,7 @@ func AudioSnippetFromPCM(pcm []byte, sampleRate, totalPCMBytes, maxBytes int) ma
|
||||
}
|
||||
dcSum += int64(s)
|
||||
}
|
||||
peakDBFS := silenceFloorDBFS
|
||||
peakDBFS := -math.Inf(1)
|
||||
if peak > 0 {
|
||||
peakDBFS = 20 * math.Log10(float64(peak)/32768.0)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package trace_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
@@ -50,32 +47,3 @@ var _ = Describe("AudioSnippetFromPCM byte cap", func() {
|
||||
Expect(out).To(HaveKey("audio_wav_base64"))
|
||||
})
|
||||
})
|
||||
|
||||
// Silent audio (RMS/peak of zero) has a true level of -∞ dBFS, but emitting
|
||||
// -Inf made the whole /api/backend-traces response fail to JSON-marshal and
|
||||
// blanked the Traces UI. The metrics must instead be finite and serializable.
|
||||
var _ = Describe("AudioSnippetFromPCM silent audio dBFS", func() {
|
||||
pcm := makePCM(snippetSeconds, snippetSampleRate) // all zeros == digital silence
|
||||
totalPCM := len(pcm)
|
||||
|
||||
It("reports finite dBFS for silence instead of -Inf", func() {
|
||||
out := trace.AudioSnippetFromPCM(pcm, snippetSampleRate, totalPCM, 0)
|
||||
|
||||
rms, ok := out["audio_rms_dbfs"].(float64)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(math.IsInf(rms, 0)).To(BeFalse(), "silent RMS must not be ±Inf")
|
||||
Expect(math.IsNaN(rms)).To(BeFalse())
|
||||
|
||||
peak, ok := out["audio_peak_dbfs"].(float64)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(math.IsInf(peak, 0)).To(BeFalse(), "silent peak must not be ±Inf")
|
||||
Expect(math.IsNaN(peak)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("produces a snippet that round-trips through encoding/json", func() {
|
||||
out := trace.AudioSnippetFromPCM(pcm, snippetSampleRate, totalPCM, 0)
|
||||
|
||||
_, err := json.Marshal(out)
|
||||
Expect(err).ToNot(HaveOccurred(), "silent-audio metrics must be JSON-marshalable")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,8 +3,6 @@ package trace
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -118,13 +116,8 @@ func RecordBackendTrace(t BackendTrace) {
|
||||
backendMu.Lock()
|
||||
maxBody := backendMaxBodyBytes
|
||||
backendMu.Unlock()
|
||||
// Always walk Data, even with no body cap configured: besides capping
|
||||
// oversized strings (maxBody > 0), the walk replaces non-finite floats
|
||||
// (Inf/NaN) that encoding/json cannot marshal. A single such value — e.g. a
|
||||
// -Inf dBFS audio metric from a silent clip — would otherwise fail the whole
|
||||
// /api/backend-traces response and blank the Traces UI.
|
||||
if t.Data != nil {
|
||||
t.Data = sanitizeData(t.Data, maxBody)
|
||||
if t.Data != nil && maxBody > 0 {
|
||||
t.Data = capDataStrings(t.Data, maxBody)
|
||||
}
|
||||
select {
|
||||
case backendLogChan <- &t:
|
||||
@@ -133,90 +126,32 @@ func RecordBackendTrace(t BackendTrace) {
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeData walks a trace Data map (recursing into nested maps and slices)
|
||||
// and makes every value safe for the /api/backend-traces JSON response:
|
||||
//
|
||||
// - When maxBytes > 0, any string longer than maxBytes is replaced with a
|
||||
// fixed-size marker that names the original byte count. The replacement is
|
||||
// intentionally short and not valid base64/JSON: it flags "this was dropped"
|
||||
// cheaply rather than keeping a partial value the UI might try to render.
|
||||
// - Non-finite floats (Inf/NaN) are replaced with nil regardless of maxBytes,
|
||||
// because encoding/json refuses to marshal them and one bad value would fail
|
||||
// the entire response.
|
||||
//
|
||||
// Other scalars (ints, bools, finite floats) pass through untouched so
|
||||
// structural fields like total_deltas or audio_sample_rate remain useful.
|
||||
//
|
||||
// The walk is copy-on-write: it runs on every RecordBackendTrace call, and in
|
||||
// the common case nothing needs rewriting, so containers are only re-allocated
|
||||
// on the paths that actually changed and untouched values keep their original
|
||||
// interface boxes instead of paying a per-value re-boxing allocation.
|
||||
func sanitizeData(data map[string]any, maxBytes int) map[string]any {
|
||||
out, _ := sanitizeMap(data, maxBytes)
|
||||
// capDataStrings walks a trace Data map and replaces any string value (at any
|
||||
// depth) that exceeds maxBytes with a fixed-size marker that names the
|
||||
// original byte count. The replacement is intentionally short and not valid
|
||||
// base64/JSON: the goal is to flag "this was dropped" cheaply, not to keep a
|
||||
// partial value that the UI might try to render. Non-string scalars and
|
||||
// non-map containers pass through untouched so structural fields like
|
||||
// total_deltas or audio_sample_rate remain useful.
|
||||
func capDataStrings(data map[string]any, maxBytes int) map[string]any {
|
||||
out := make(map[string]any, len(data))
|
||||
for k, v := range data {
|
||||
out[k] = capValue(v, maxBytes)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeMap(m map[string]any, maxBytes int) (map[string]any, bool) {
|
||||
var out map[string]any
|
||||
for k, v := range m {
|
||||
nv, changed := sanitizeValue(v, maxBytes)
|
||||
if changed && out == nil {
|
||||
// First change: fork the map. Entries already visited were
|
||||
// unchanged, so a full copy then overwriting as we go is exact.
|
||||
out = make(map[string]any, len(m))
|
||||
maps.Copy(out, m)
|
||||
}
|
||||
if out != nil {
|
||||
out[k] = nv
|
||||
}
|
||||
}
|
||||
if out == nil {
|
||||
return m, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func sanitizeSlice(s []any, maxBytes int) ([]any, bool) {
|
||||
var out []any
|
||||
for i, v := range s {
|
||||
nv, changed := sanitizeValue(v, maxBytes)
|
||||
if changed && out == nil {
|
||||
out = make([]any, len(s))
|
||||
copy(out, s)
|
||||
}
|
||||
if out != nil {
|
||||
out[i] = nv
|
||||
}
|
||||
}
|
||||
if out == nil {
|
||||
return s, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func sanitizeValue(v any, maxBytes int) (any, bool) {
|
||||
func capValue(v any, maxBytes int) any {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
if maxBytes > 0 && len(val) > maxBytes {
|
||||
return fmt.Sprintf("<truncated: %d bytes>", len(val)), true
|
||||
if len(val) > maxBytes {
|
||||
return fmt.Sprintf("<truncated: %d bytes>", len(val))
|
||||
}
|
||||
return v, false
|
||||
case float64:
|
||||
if math.IsInf(val, 0) || math.IsNaN(val) {
|
||||
return nil, true
|
||||
}
|
||||
return v, false
|
||||
case float32:
|
||||
if f := float64(val); math.IsInf(f, 0) || math.IsNaN(f) {
|
||||
return nil, true
|
||||
}
|
||||
return v, false
|
||||
return val
|
||||
case map[string]any:
|
||||
return sanitizeMap(val, maxBytes)
|
||||
case []any:
|
||||
return sanitizeSlice(val, maxBytes)
|
||||
return capDataStrings(val, maxBytes)
|
||||
default:
|
||||
return v, false
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
package trace_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
)
|
||||
|
||||
// encoding/json cannot marshal ±Inf or NaN. The /api/backend-traces endpoint
|
||||
// serializes the whole buffer with one json call, so a single non-finite float
|
||||
// in any trace's Data map (e.g. a -Inf dBFS audio metric from a silent clip)
|
||||
// would fail the entire response and blank the Traces UI. RecordBackendTrace
|
||||
// must scrub those values regardless of whether a body cap is configured.
|
||||
var _ = Describe("RecordBackendTrace non-finite float sanitization", func() {
|
||||
BeforeEach(func() {
|
||||
// maxBodyBytes 0 == no body cap: float sanitization must still run.
|
||||
trace.InitBackendTracingIfEnabled(64, 0)
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
It("replaces ±Inf and NaN with nil so the response stays JSON-marshalable", func() {
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceTranscription,
|
||||
ModelName: "m",
|
||||
Data: map[string]any{
|
||||
"audio_rms_dbfs": math.Inf(-1),
|
||||
"audio_peak_dbfs": math.Inf(1),
|
||||
"weird": math.NaN(),
|
||||
"audio_duration_s": 1.5, // finite siblings must survive
|
||||
},
|
||||
})
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
|
||||
Expect(got.Data["audio_rms_dbfs"]).To(BeNil())
|
||||
Expect(got.Data["audio_peak_dbfs"]).To(BeNil())
|
||||
Expect(got.Data["weird"]).To(BeNil())
|
||||
Expect(got.Data["audio_duration_s"]).To(Equal(1.5), "finite floats must pass through untouched")
|
||||
|
||||
_, err := json.Marshal(trace.GetBackendTraces())
|
||||
Expect(err).ToNot(HaveOccurred(), "the whole trace buffer must marshal even with non-finite inputs")
|
||||
})
|
||||
|
||||
It("scrubs non-finite floats nested in maps and slices", func() {
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceLLM,
|
||||
ModelName: "m",
|
||||
Data: map[string]any{
|
||||
"nested": map[string]any{
|
||||
"logprob": math.Inf(-1),
|
||||
"ok": 0.25,
|
||||
},
|
||||
"scores": []any{1.0, math.Inf(1), math.NaN()},
|
||||
},
|
||||
})
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
|
||||
nested := got.Data["nested"].(map[string]any)
|
||||
Expect(nested["logprob"]).To(BeNil())
|
||||
Expect(nested["ok"]).To(Equal(0.25))
|
||||
|
||||
scores := got.Data["scores"].([]any)
|
||||
Expect(scores[0]).To(Equal(1.0))
|
||||
Expect(scores[1]).To(BeNil())
|
||||
Expect(scores[2]).To(BeNil())
|
||||
|
||||
_, err := json.Marshal(trace.GetBackendTraces())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user