mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-28 02:17:00 -04:00
Compare commits
6 Commits
feat/ik-ll
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ade9cc9e37 | ||
|
|
471e38e4e7 | ||
|
|
f3d829e2ef | ||
|
|
91885c2c7e | ||
|
|
f1fcafb888 | ||
|
|
fdff114701 |
6
.github/workflows/test-extra.yml
vendored
6
.github/workflows/test-extra.yml
vendored
@@ -1008,7 +1008,11 @@ jobs:
|
||||
# image + working dir.
|
||||
tests-vibevoice-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.vibevoice-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
# Skip on release tag pushes: the ASR Q4_K model is ~10 GB and cannot be
|
||||
# pulled from HF within the inner `go test -timeout 30m` budget on a CI
|
||||
# runner, so every tag build hung and timed out. Still runs on PRs/branch
|
||||
# pushes that touch vibevoice-cpp so regressions are caught off the release path.
|
||||
if: (needs.detect-changes.outputs.vibevoice-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true') && !startsWith(github.ref, 'refs/tags/')
|
||||
runs-on: bigger-runner
|
||||
timeout-minutes: 150
|
||||
steps:
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
## Multimodal support is provided by the in-tree `mtmd` library target
|
||||
## (examples/mtmd/), which the grpc-server links and includes below. clip/llava
|
||||
## were pruned upstream; the high-level mtmd_* / mtmd_helper_* API is used instead.
|
||||
## Clip/LLaVA library for multimodal support — built locally from copied sources
|
||||
set(TARGET myclip)
|
||||
add_library(${TARGET} clip.cpp clip.h llava.cpp llava.h)
|
||||
install(TARGETS ${TARGET} LIBRARY)
|
||||
target_include_directories(myclip PUBLIC .)
|
||||
target_include_directories(myclip PUBLIC ../..)
|
||||
target_include_directories(myclip PUBLIC ../../common)
|
||||
target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if (NOT MSVC)
|
||||
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual)
|
||||
endif()
|
||||
|
||||
set(TARGET grpc-server)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -58,16 +67,12 @@ add_library(hw_grpc_proto
|
||||
${hw_proto_hdrs} )
|
||||
|
||||
add_executable(${TARGET} grpc-server.cpp json.hpp)
|
||||
# mtmd public headers (mtmd.h / mtmd-helper.h) live in examples/mtmd/.
|
||||
# Linking the mtmd target also propagates this include dir, but we add it
|
||||
# explicitly for clarity.
|
||||
target_include_directories(${TARGET} PRIVATE ../mtmd)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
|
||||
target_link_libraries(${TARGET} PRIVATE common llama myclip ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
|
||||
absl::flags_parse
|
||||
gRPC::${_REFLECTION}
|
||||
gRPC::${_GRPC_GRPCPP}
|
||||
protobuf::${_PROTOBUF_LIBPROTOBUF})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(${TARGET} BUILD_INFO)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=f96eaddba8bed6a9a5e628bbf6a566775c70b49c
|
||||
IK_LLAMA_VERSION?=b84902d2ad27c34f989f23947200c4b91b1568fd
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <getopt.h>
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "json.hpp"
|
||||
@@ -45,9 +45,7 @@ using backend::HealthMessage;
|
||||
|
||||
///// LLAMA.CPP server code below
|
||||
|
||||
// Match mtmd.h and ik_llama's server/common headers, which all use
|
||||
// nlohmann::ordered_json; a plain nlohmann::json alias collides at global scope.
|
||||
using json = nlohmann::ordered_json;
|
||||
using json = nlohmann::json;
|
||||
|
||||
struct server_params
|
||||
{
|
||||
@@ -221,11 +219,6 @@ struct llama_client_slot
|
||||
|
||||
// multimodal
|
||||
std::vector<slot_image> images;
|
||||
// Full prompt with mtmd media markers (mtmd_default_marker()) substituted in
|
||||
// place of the legacy [img-N] tags, covering the text up to and including the
|
||||
// last image. The text after the last image is kept in params.input_suffix and
|
||||
// decoded through the normal token path so the sampling loop is unchanged.
|
||||
std::string mtmd_prompt;
|
||||
|
||||
// stats
|
||||
size_t sent_count = 0;
|
||||
@@ -259,14 +252,14 @@ struct llama_client_slot
|
||||
|
||||
for (slot_image & img : images)
|
||||
{
|
||||
if (img.bitmap) {
|
||||
mtmd_bitmap_free(img.bitmap);
|
||||
img.bitmap = nullptr;
|
||||
free(img.image_embedding);
|
||||
if (img.img_data) {
|
||||
clip_image_u8_free(img.img_data);
|
||||
}
|
||||
img.prefix_prompt = "";
|
||||
}
|
||||
|
||||
images.clear();
|
||||
mtmd_prompt = "";
|
||||
}
|
||||
|
||||
bool has_budget(gpt_params &global_params) {
|
||||
@@ -403,13 +396,46 @@ struct llama_metrics {
|
||||
}
|
||||
};
|
||||
|
||||
struct llava_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_server_context
|
||||
{
|
||||
llama_model *model = nullptr;
|
||||
llama_context *ctx = nullptr;
|
||||
const llama_vocab * vocab = nullptr;
|
||||
|
||||
mtmd_context *mctx = nullptr;
|
||||
clip_ctx *clp_ctx = nullptr;
|
||||
|
||||
gpt_params params;
|
||||
|
||||
@@ -465,6 +491,11 @@ struct llama_server_context
|
||||
if (!params.mmproj.path.empty()) {
|
||||
multimodal = true;
|
||||
LOG_INFO("Multi Modal Mode Enabled", {});
|
||||
clp_ctx = clip_model_load(params.mmproj.path.c_str(), /*verbosity=*/ 1);
|
||||
if(clp_ctx == nullptr) {
|
||||
LOG_ERR("unable to load clip model: %s", params.mmproj.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (params.n_ctx < 2048) { // request larger context for the image embedding
|
||||
params.n_ctx = 2048;
|
||||
@@ -481,24 +512,10 @@ struct llama_server_context
|
||||
}
|
||||
|
||||
if (multimodal) {
|
||||
// mtmd_init_from_file requires the already-loaded text model, so it must
|
||||
// run AFTER llama_init_from_gpt_params. It validates the projector
|
||||
// against the model internally and returns nullptr on dim mismatch, so
|
||||
// the explicit clip_n_mmproj_embd check is no longer needed.
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params.n_threads_mtmd != -1 ? params.n_threads_mtmd
|
||||
: params.n_threads_batch != -1 ? params.n_threads_batch
|
||||
: params.n_threads;
|
||||
mparams.verbosity = GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED
|
||||
: LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
mctx = mtmd_init_from_file(params.mmproj.path.c_str(), model, mparams);
|
||||
if (mctx == nullptr) {
|
||||
LOG_ERR("unable to load multimodal projector: %s", params.mmproj.path.c_str());
|
||||
const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
|
||||
const int n_embd_llm = llama_model_n_embd(model);
|
||||
if (n_embd_clip != n_embd_llm) {
|
||||
LOG("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
return false;
|
||||
@@ -848,8 +865,8 @@ struct llama_server_context
|
||||
|
||||
slot_image img_sl;
|
||||
img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
|
||||
img_sl.bitmap = mtmd_helper_bitmap_init_from_buf(mctx, image_buffer.data(), image_buffer.size());
|
||||
if (img_sl.bitmap == nullptr)
|
||||
img_sl.img_data = clip_image_u8_init();
|
||||
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
||||
{
|
||||
LOG_ERR("%s: failed to load image, slot_id: %d, img_sl_id: %d",
|
||||
__func__,
|
||||
@@ -862,74 +879,50 @@ struct llama_server_context
|
||||
{"slot_id", slot->id},
|
||||
{"img_sl_id", img_sl.id}
|
||||
});
|
||||
img_sl.request_encode_image = true;
|
||||
slot->images.push_back(img_sl);
|
||||
}
|
||||
// Translate the legacy [img-N] tags into mtmd media markers, in
|
||||
// order, and collect the matching bitmaps in marker order so they
|
||||
// line up with the markers passed to mtmd_tokenize(). The text after
|
||||
// the last image stays in input_suffix and is decoded through the
|
||||
// normal token path, so the sampling loop is unchanged.
|
||||
// example: system prompt [img-102] user [img-103] describe [img-134]
|
||||
// process prompt
|
||||
// example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]}
|
||||
if (slot->images.size() > 0 && !slot->prompt.is_array())
|
||||
{
|
||||
const std::string marker = mtmd_default_marker();
|
||||
std::string prompt = slot->prompt.get<std::string>();
|
||||
std::string built_prompt;
|
||||
std::vector<slot_image> ordered;
|
||||
size_t pos = 0, copy_from = 0;
|
||||
size_t pos = 0, begin_prefix = 0;
|
||||
std::string pattern = "[img-";
|
||||
|
||||
auto free_images = [&]() {
|
||||
for (slot_image &img : slot->images) {
|
||||
if (img.bitmap) {
|
||||
mtmd_bitmap_free(img.bitmap);
|
||||
img.bitmap = nullptr;
|
||||
}
|
||||
}
|
||||
slot->images.clear();
|
||||
};
|
||||
|
||||
while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
|
||||
size_t tag_begin = pos;
|
||||
size_t end_prefix = pos;
|
||||
pos += pattern.length();
|
||||
size_t end_pos = prompt.find(']', pos);
|
||||
if (end_pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
std::string image_id = prompt.substr(pos, end_pos - pos);
|
||||
try
|
||||
if (end_pos != std::string::npos)
|
||||
{
|
||||
int img_id = std::stoi(image_id);
|
||||
bool found = false;
|
||||
for (slot_image &img : slot->images)
|
||||
std::string image_id = prompt.substr(pos, end_pos - pos);
|
||||
try
|
||||
{
|
||||
if (img.id == img_id) {
|
||||
found = true;
|
||||
// text before this tag, then the media marker
|
||||
built_prompt += prompt.substr(copy_from, tag_begin - copy_from);
|
||||
built_prompt += marker;
|
||||
copy_from = end_pos + 1;
|
||||
ordered.push_back(img);
|
||||
break;
|
||||
int img_id = std::stoi(image_id);
|
||||
bool found = false;
|
||||
for (slot_image &img : slot->images)
|
||||
{
|
||||
if (img.id == img_id) {
|
||||
found = true;
|
||||
img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
|
||||
begin_prefix = end_pos + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
LOG("ERROR: Image with id: %i, not found.\n", img_id);
|
||||
free_images();
|
||||
if (!found) {
|
||||
LOG("ERROR: Image with id: %i, not found.\n", img_id);
|
||||
slot->images.clear();
|
||||
return false;
|
||||
}
|
||||
} catch (const std::invalid_argument& e) {
|
||||
LOG("Invalid image number id in prompt\n");
|
||||
slot->images.clear();
|
||||
return false;
|
||||
}
|
||||
} catch (const std::invalid_argument& e) {
|
||||
LOG("Invalid image number id in prompt\n");
|
||||
free_images();
|
||||
return false;
|
||||
}
|
||||
pos = end_pos + 1;
|
||||
}
|
||||
// bitmaps are consumed in marker order by mtmd_tokenize()
|
||||
slot->images = ordered;
|
||||
slot->mtmd_prompt = built_prompt;
|
||||
slot->prompt = "";
|
||||
slot->params.input_suffix = prompt.substr(copy_from);
|
||||
slot->params.input_suffix = prompt.substr(begin_prefix);
|
||||
slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
|
||||
}
|
||||
}
|
||||
@@ -1183,10 +1176,21 @@ struct llama_server_context
|
||||
|
||||
bool process_images(llama_client_slot &slot) const
|
||||
{
|
||||
// With the mtmd pipeline, image encoding is no longer eager: the bitmaps
|
||||
// are tokenized and encoded together with the surrounding text inside
|
||||
// ingest_images() via mtmd_tokenize() + mtmd_helper_eval_chunks(). This
|
||||
// just reports whether the slot carries any images to process.
|
||||
for (slot_image &img : slot.images)
|
||||
{
|
||||
if (!img.request_encode_image)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
|
||||
LOG("Error processing the given image");
|
||||
return false;
|
||||
}
|
||||
|
||||
img.request_encode_image = false;
|
||||
}
|
||||
|
||||
return slot.images.size() > 0;
|
||||
}
|
||||
|
||||
@@ -1431,70 +1435,69 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
|
||||
// Tokenize the multimodal prompt (text interleaved with media markers) together
|
||||
// with the slot's bitmaps, then decode the resulting chunks into the llama
|
||||
// context via the high-level mtmd helper. The helper runs llama_decode() on the
|
||||
// text chunks and mtmd_encode() + llama_decode() on the image chunks, handling
|
||||
// batching and any pre/post decode setup (e.g. non-causal attention for gemma3).
|
||||
// Advances slot.n_past by the number of positions consumed, then leaves the
|
||||
// post-image suffix tokens in `batch` so the normal decode + sampling loop
|
||||
// produces the first generated token.
|
||||
// for multiple images processing
|
||||
bool ingest_images(llama_client_slot &slot, int n_batch)
|
||||
{
|
||||
if (mctx == nullptr)
|
||||
{
|
||||
LOG("%s : multimodal context is not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
int image_idx = 0;
|
||||
|
||||
// bitmaps stay owned by slot.images (freed on reset()); pass non-owning ptrs
|
||||
std::vector<const mtmd_bitmap *> bitmaps;
|
||||
bitmaps.reserve(slot.images.size());
|
||||
for (const slot_image &img : slot.images)
|
||||
while (image_idx < (int) slot.images.size())
|
||||
{
|
||||
bitmaps.push_back(img.bitmap);
|
||||
}
|
||||
slot_image &img = slot.images[image_idx];
|
||||
|
||||
mtmd_input_text inp_txt;
|
||||
inp_txt.text = slot.mtmd_prompt.c_str();
|
||||
inp_txt.add_special = add_bos_token;
|
||||
inp_txt.parse_special = true;
|
||||
// process prefix prompt
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||
{
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
if (llama_decode(ctx, batch_view))
|
||||
{
|
||||
LOG("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
int32_t res = mtmd_tokenize(mctx,
|
||||
chunks.ptr.get(),
|
||||
&inp_txt,
|
||||
bitmaps.data(),
|
||||
bitmaps.size());
|
||||
if (res != 0)
|
||||
{
|
||||
LOG("%s : failed to tokenize multimodal prompt, res = %d\n", __func__, res);
|
||||
return false;
|
||||
}
|
||||
// process image with llm
|
||||
for (int i = 0; i < img.image_tokens; i += n_batch)
|
||||
{
|
||||
int n_eval = img.image_tokens - i;
|
||||
if (n_eval > n_batch)
|
||||
{
|
||||
n_eval = n_batch;
|
||||
}
|
||||
|
||||
const llama_pos start_pos = (llama_pos) system_tokens.size() + slot.n_past;
|
||||
llama_pos new_n_past = start_pos;
|
||||
if (mtmd_helper_eval_chunks(mctx,
|
||||
ctx,
|
||||
chunks.ptr.get(),
|
||||
start_pos,
|
||||
slot.id,
|
||||
n_batch,
|
||||
/*logits_last=*/ false,
|
||||
&new_n_past) != 0)
|
||||
{
|
||||
LOG("%s : failed to eval multimodal chunks\n", __func__);
|
||||
return false;
|
||||
}
|
||||
slot.n_past += (int32_t) (new_n_past - start_pos);
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
float * embd = img.image_embedding + i * n_embd;
|
||||
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, slot.n_past, 0);
|
||||
if (llama_decode(ctx, llava_batch.batch))
|
||||
{
|
||||
LOG("%s : failed to eval image\n", __func__);
|
||||
return false;
|
||||
}
|
||||
slot.n_past += n_eval;
|
||||
}
|
||||
image_idx++;
|
||||
|
||||
// queue the post-image suffix text for the normal decode + sampling path
|
||||
common_batch_clear(batch);
|
||||
std::vector<llama_token> suffix_tokens = tokenize(slot.params.input_suffix, false);
|
||||
for (llama_token tok : suffix_tokens)
|
||||
{
|
||||
common_batch_add(batch, tok, system_tokens.size() + slot.n_past, { slot.id }, false);
|
||||
slot.n_past += 1;
|
||||
common_batch_clear(batch);
|
||||
|
||||
// append prefix of next image
|
||||
const auto json_prompt = (image_idx >= (int) slot.images.size()) ?
|
||||
slot.params.input_suffix : // no more images, then process suffix prompt
|
||||
(json)(slot.images[image_idx].prefix_prompt);
|
||||
|
||||
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
|
||||
for (int i = 0; i < (int) append_tokens.size(); ++i)
|
||||
{
|
||||
common_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
|
||||
slot.n_past += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -1881,11 +1884,8 @@ struct llama_server_context
|
||||
|
||||
const bool has_images = process_images(slot);
|
||||
|
||||
// For the multimodal path the whole pre-image / inter-image text is
|
||||
// tokenized and decoded inside ingest_images() via mtmd, so no prefix
|
||||
// tokens are queued here; the post-image suffix is appended by
|
||||
// ingest_images() for the normal decode + sampling loop.
|
||||
std::vector<llama_token> prefix_tokens = has_images ? std::vector<llama_token>() : prompt_tokens;
|
||||
// process the prefix of first image
|
||||
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
|
||||
|
||||
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
--- a/examples/llava/clip.cpp
|
||||
+++ b/examples/llava/clip.cpp
|
||||
@@ -2494,7 +2494,7 @@
|
||||
}
|
||||
new_data = work.data();
|
||||
|
||||
- new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr);
|
||||
+ new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr, nullptr);
|
||||
} else {
|
||||
new_type = cur->type;
|
||||
new_data = cur->data;
|
||||
@@ -17,9 +17,28 @@ cp -r grpc-server.cpp llama.cpp/examples/grpc-server/
|
||||
cp -r utils.hpp llama.cpp/examples/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/examples/grpc-server/
|
||||
|
||||
## Multimodal support is provided by the `mtmd` library target (examples/mtmd/),
|
||||
## which the grpc-server links and includes directly. No source copy is needed:
|
||||
## clip/llava were pruned upstream and the high-level mtmd_* API is used instead.
|
||||
## Copy clip/llava files for multimodal support (built as myclip library)
|
||||
cp -rfv llama.cpp/examples/llava/clip.h llama.cpp/examples/grpc-server/clip.h
|
||||
cp -rfv llama.cpp/examples/llava/clip.cpp llama.cpp/examples/grpc-server/clip.cpp
|
||||
cp -rfv llama.cpp/examples/llava/llava.cpp llama.cpp/examples/grpc-server/llava.cpp
|
||||
# Prepend llama.h include to llava.h
|
||||
echo '#include "llama.h"' > llama.cpp/examples/grpc-server/llava.h
|
||||
cat llama.cpp/examples/llava/llava.h >> llama.cpp/examples/grpc-server/llava.h
|
||||
# Copy clip-impl.h if it exists
|
||||
if [ -f llama.cpp/examples/llava/clip-impl.h ]; then
|
||||
cp -rfv llama.cpp/examples/llava/clip-impl.h llama.cpp/examples/grpc-server/clip-impl.h
|
||||
fi
|
||||
# Copy stb_image.h
|
||||
if [ -f llama.cpp/vendor/stb/stb_image.h ]; then
|
||||
cp -rfv llama.cpp/vendor/stb/stb_image.h llama.cpp/examples/grpc-server/stb_image.h
|
||||
elif [ -f llama.cpp/common/stb_image.h ]; then
|
||||
cp -rfv llama.cpp/common/stb_image.h llama.cpp/examples/grpc-server/stb_image.h
|
||||
fi
|
||||
|
||||
## Fix API compatibility in llava.cpp (llama_n_embd -> llama_model_n_embd)
|
||||
if [ -f llama.cpp/examples/grpc-server/llava.cpp ]; then
|
||||
sed -i 's/llama_n_embd(/llama_model_n_embd(/g' llama.cpp/examples/grpc-server/llava.cpp
|
||||
fi
|
||||
|
||||
set +e
|
||||
if grep -q "grpc-server" llama.cpp/examples/CMakeLists.txt; then
|
||||
|
||||
@@ -11,12 +11,9 @@
|
||||
|
||||
#include "json.hpp"
|
||||
|
||||
#include "mtmd.h"
|
||||
#include "clip.h"
|
||||
|
||||
// mtmd.h and ik_llama's entire server/common stack (chat.h, server-common.h,
|
||||
// server-task.h, ...) declare `using json = nlohmann::ordered_json`, so match it
|
||||
// here: a plain `nlohmann::json` alias collides with mtmd.h's at global scope.
|
||||
using json = nlohmann::ordered_json;
|
||||
using json = nlohmann::json;
|
||||
|
||||
extern bool server_verbose;
|
||||
|
||||
@@ -114,12 +111,13 @@ struct slot_image
|
||||
{
|
||||
int32_t id;
|
||||
|
||||
// mtmd bitmap (image/audio) decoded from the request buffer. Owned by the
|
||||
// slot; freed via mtmd_bitmap_free() on reset. The high-level mtmd pipeline
|
||||
// (mtmd_tokenize + mtmd_helper_eval_chunks) consumes these directly, so the
|
||||
// legacy eager-encode fields (embedding/tokens) and per-image prefix prompt
|
||||
// are no longer needed.
|
||||
mtmd_bitmap * bitmap = nullptr;
|
||||
bool request_encode_image = false;
|
||||
float * image_embedding = nullptr;
|
||||
int32_t image_tokens = 0;
|
||||
|
||||
clip_image_u8 * img_data;
|
||||
|
||||
std::string prefix_prompt; // before of this image
|
||||
};
|
||||
|
||||
// completion token output with probabilities
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=8caa3f908ae6d4a4bef531e73b9a969f266a3d1f
|
||||
STABLEDIFFUSION_GGML_VERSION?=9956436c925a367daeab097598b1ea1f32d3503f
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -355,6 +355,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
SharedModels: cfg.Distributed.SharedModels,
|
||||
})
|
||||
|
||||
// Wire staging-progress broadcasting so file-staging shows up on every
|
||||
|
||||
@@ -160,6 +160,7 @@ type RunCMD struct {
|
||||
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedSharedModels bool `env:"LOCALAI_DISTRIBUTED_SHARED_MODELS" default:"false" help:"Assert that every node mounts the SAME models directory at the SAME path (shared volume). When true, the router skips staging model files to workers and loads them directly from the shared path, avoiding re-downloads." group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
@@ -310,6 +311,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.DistributedRequireAuth {
|
||||
opts = append(opts, config.EnableDistributedRequireAuth)
|
||||
}
|
||||
if r.DistributedSharedModels {
|
||||
opts = append(opts, config.EnableDistributedSharedModels)
|
||||
}
|
||||
if r.NatsAccountSeed != "" {
|
||||
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
|
||||
}
|
||||
|
||||
@@ -31,6 +31,14 @@ type DistributedConfig struct {
|
||||
// available to enforce just one layer.
|
||||
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
// SharedModels asserts that every node (frontend and workers) mounts the
|
||||
// SAME models directory at the SAME path (e.g. a shared volume, as in
|
||||
// docker-compose.distributed.yaml). When true, the router skips staging
|
||||
// model files to workers entirely: the frontend's absolute model paths are
|
||||
// already valid on the worker, so re-uploading them into a per-model
|
||||
// subdirectory only re-downloads what is already present (#10556). Default
|
||||
// false preserves the historical per-node staging behavior.
|
||||
SharedModels bool // --distributed-shared-models / LOCALAI_DISTRIBUTED_SHARED_MODELS
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
@@ -282,6 +290,13 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// EnableDistributedSharedModels marks the cluster as sharing one models
|
||||
// directory across all nodes, so the router skips staging model files to
|
||||
// workers (see DistributedConfig.SharedModels).
|
||||
var EnableDistributedSharedModels = func(o *ApplicationConfig) {
|
||||
o.Distributed.SharedModels = true
|
||||
}
|
||||
|
||||
// DisablePrefixCache turns off prefix-cache-aware routing (falls back to
|
||||
// round-robin). Prefix-cache routing is enabled by default in distributed mode.
|
||||
var DisablePrefixCache = func(o *ApplicationConfig) {
|
||||
|
||||
@@ -25,8 +25,8 @@ var (
|
||||
|
||||
type LlamaCPPImporter struct{}
|
||||
|
||||
func (i *LlamaCPPImporter) Name() string { return "llama-cpp" }
|
||||
func (i *LlamaCPPImporter) Modality() string { return "text" }
|
||||
func (i *LlamaCPPImporter) Name() string { return "llama-cpp" }
|
||||
func (i *LlamaCPPImporter) Modality() string { return "text" }
|
||||
func (i *LlamaCPPImporter) AutoDetects() bool { return true }
|
||||
|
||||
// AdditionalBackends advertises drop-in replacements that share the
|
||||
@@ -293,7 +293,7 @@ func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardG
|
||||
for _, pref := range prefs {
|
||||
lower := strings.ToLower(pref)
|
||||
for i := range groups {
|
||||
if strings.Contains(strings.ToLower(groups[i].Base), lower) {
|
||||
if quantTokenMatches(strings.ToLower(groups[i].Base), lower) {
|
||||
return &groups[i]
|
||||
}
|
||||
}
|
||||
@@ -301,6 +301,39 @@ func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardG
|
||||
return &groups[len(groups)-1]
|
||||
}
|
||||
|
||||
// quantTokenMatches reports whether pref appears in base as a whole token
|
||||
// rather than as a substring of a larger alphanumeric run. Both arguments
|
||||
// must already be lowercased.
|
||||
//
|
||||
// A plain strings.Contains is wrong here: `f16` is a substring of `bf16`, so
|
||||
// asking for the `F16` quant used to wrongly select a `BF16` file (#10559).
|
||||
// Only the OUTER edges of the matched preference must hit a boundary — a
|
||||
// non-alphanumeric char (or the start/end of base). Separators inside the
|
||||
// preference itself (e.g. `ud-q4_k_xl`) are intentionally left untouched.
|
||||
func quantTokenMatches(base, pref string) bool {
|
||||
if pref == "" {
|
||||
return false
|
||||
}
|
||||
for start := strings.Index(base, pref); start != -1; {
|
||||
end := start + len(pref)
|
||||
leftOK := start == 0 || !isAlphaNum(base[start-1])
|
||||
rightOK := end == len(base) || !isAlphaNum(base[end])
|
||||
if leftOK && rightOK {
|
||||
return true
|
||||
}
|
||||
next := strings.Index(base[start+1:], pref)
|
||||
if next == -1 {
|
||||
break
|
||||
}
|
||||
start += next + 1
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isAlphaNum(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9')
|
||||
}
|
||||
|
||||
// maybeApplyMTPDefaults parses the picked GGUF header (range-fetched over
|
||||
// HTTP for HF/URL imports) and, if the file declares a Multi-Token Prediction
|
||||
// head, appends the auto-MTP option keys to modelConfig.Options. Failures
|
||||
|
||||
@@ -374,6 +374,104 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("quant token boundary matching", func() {
|
||||
// Regression for #10559: the quant preference must match as a whole
|
||||
// token, not as a substring. Asking for `F16` used to select a
|
||||
// `BF16` mmproj because strings.Contains("...bf16.gguf", "f16") is
|
||||
// true — the leading `b` was ignored.
|
||||
|
||||
const repoBase = "https://huggingface.co/acme/example-GGUF/resolve/main/"
|
||||
|
||||
hfFile := func(path, sha string) hfapi.ModelFile {
|
||||
return hfapi.ModelFile{
|
||||
Path: path,
|
||||
SHA256: sha,
|
||||
URL: repoBase + path,
|
||||
}
|
||||
}
|
||||
|
||||
withHF := func(preferences string, files ...hfapi.ModelFile) Details {
|
||||
d := Details{
|
||||
URI: "https://huggingface.co/acme/example-GGUF",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "acme/example-GGUF",
|
||||
Files: files,
|
||||
},
|
||||
}
|
||||
if preferences != "" {
|
||||
d.Preferences = json.RawMessage(preferences)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
It("selects the F16 mmproj over BF16 (BF16 listed first)", func() {
|
||||
details := withHF(`{"name":"VL","mmproj_quantizations":"F16"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "model"),
|
||||
hfFile("mmproj-x-BF16.gguf", "bf16"),
|
||||
hfFile("mmproj-x-F16.gguf", "f16"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/VL/mmproj-x-F16.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("BF16"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("selects the F16 mmproj over BF16 (F16 listed first)", func() {
|
||||
details := withHF(`{"name":"VL","mmproj_quantizations":"F16"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "model"),
|
||||
hfFile("mmproj-x-F16.gguf", "f16"),
|
||||
hfFile("mmproj-x-BF16.gguf", "bf16"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/VL/mmproj-x-F16.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("BF16"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("selects BF16 when BF16 is the requested mmproj quant", func() {
|
||||
details := withHF(`{"name":"VL","mmproj_quantizations":"BF16"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "model"),
|
||||
hfFile("mmproj-x-F16.gguf", "f16"),
|
||||
hfFile("mmproj-x-BF16.gguf", "bf16"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/VL/mmproj-x-BF16.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("still matches a normal model quant with internal separators", func() {
|
||||
// ud-q4_k_xl contains `-`/`_` internally; only the outer edges
|
||||
// must hit a token boundary.
|
||||
details := withHF(`{"name":"M","quantizations":"ud-q4_k_xl"}`,
|
||||
hfFile("model-UD-Q4_K_XL.gguf", "xl"),
|
||||
hfFile("model-Q3_K_M.gguf", "q3"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/M/model-UD-Q4_K_XL.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("falls back to the last group when no preference matches", func() {
|
||||
details := withHF(`{"name":"M","quantizations":"Q2_K"}`,
|
||||
hfFile("model-Q8_0.gguf", "q8"),
|
||||
hfFile("model-Q3_K_M.gguf", "q3"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/M/model-Q3_K_M.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("AdditionalBackends", func() {
|
||||
It("advertises ik-llama-cpp and turboquant as drop-in replacements", func() {
|
||||
entries := importer.AdditionalBackends()
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
@@ -550,12 +551,23 @@ func DeleteBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerF
|
||||
}
|
||||
|
||||
// ListBackendsOnNodeEndpoint lists installed backends on a worker node via NATS.
|
||||
func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
|
||||
func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
nodeID := c.Param("id")
|
||||
// Agent-type workers don't run backends and never subscribe to the
|
||||
// nodes.<id>.backend.list NATS subject, so the request would hang
|
||||
// until timeout with "no responders". Their backend list is simply
|
||||
// empty. Mirror the aggregate-list guard in managers_distributed.go
|
||||
// (skip nodes whose NodeType is set and not "backend") so the
|
||||
// single-node and cluster-wide views stay consistent.
|
||||
if node, err := registry.Get(c.Request().Context(), nodeID); err == nil {
|
||||
if node.NodeType != "" && node.NodeType != nodes.NodeTypeBackend {
|
||||
return c.JSON(http.StatusOK, []messaging.NodeBackendInfo{})
|
||||
}
|
||||
}
|
||||
if unloader == nil {
|
||||
return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
|
||||
}
|
||||
nodeID := c.Param("id")
|
||||
reply, err := unloader.ListBackends(nodeID)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to list backends on node", "node", nodeID, "error", err)
|
||||
|
||||
103
core/http/endpoints/localai/nodes_backends_list_test.go
Normal file
103
core/http/endpoints/localai/nodes_backends_list_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// stubNodeCommandSender records whether ListBackends was invoked so the test can
|
||||
// assert the endpoint short-circuits (no NATS request) for agent-type nodes.
|
||||
type stubNodeCommandSender struct {
|
||||
listBackendsCalled bool
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) InstallBackend(_, _, _, _, _, _, _ string, _ int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
|
||||
return &messaging.BackendInstallReply{}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) UpgradeBackend(_, _, _, _, _, _ string, _ int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
|
||||
return &messaging.BackendUpgradeReply{}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) {
|
||||
return &messaging.BackendDeleteReply{Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) ListBackends(_ string) (*messaging.BackendListReply, error) {
|
||||
s.listBackendsCalled = true
|
||||
return &messaging.BackendListReply{Backends: []messaging.NodeBackendInfo{{Name: "llama-cpp"}}}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) StopBackend(_, _ string) error { return nil }
|
||||
|
||||
func (s *stubNodeCommandSender) UnloadModelOnNode(_, _ string) error { return nil }
|
||||
|
||||
var _ = Describe("ListBackendsOnNodeEndpoint", func() {
|
||||
var registry *nodes.NodeRegistry
|
||||
|
||||
BeforeEach(func() {
|
||||
db := testutil.SetupTestDB()
|
||||
var err error
|
||||
registry, err = nodes.NewNodeRegistry(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
callEndpoint := func(unloader nodes.NodeCommandSender, nodeID string) *httptest.ResponseRecorder {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues(nodeID)
|
||||
handler := ListBackendsOnNodeEndpoint(unloader, registry)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
It("returns an empty list for an agent node without issuing a NATS request", func() {
|
||||
ctx := context.Background()
|
||||
node := &nodes.BackendNode{Name: "agent-1", NodeType: nodes.NodeTypeAgent}
|
||||
Expect(registry.Register(ctx, node, true)).To(Succeed())
|
||||
|
||||
stub := &stubNodeCommandSender{}
|
||||
rec := callEndpoint(stub, node.ID)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(stub.listBackendsCalled).To(BeFalse(),
|
||||
"agent workers don't subscribe to backend.list; the endpoint must not issue the doomed NATS request")
|
||||
|
||||
var list []messaging.NodeBackendInfo
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &list)).To(Succeed())
|
||||
Expect(list).To(BeEmpty())
|
||||
// Must be `[]`, not `null`, so the UI can render it.
|
||||
Expect(rec.Body.String()).To(ContainSubstring("[]"))
|
||||
})
|
||||
|
||||
It("consults the unloader (NATS) for a backend node", func() {
|
||||
ctx := context.Background()
|
||||
node := &nodes.BackendNode{Name: "backend-1", NodeType: nodes.NodeTypeBackend, Address: "10.0.0.1:50051"}
|
||||
Expect(registry.Register(ctx, node, true)).To(Succeed())
|
||||
|
||||
stub := &stubNodeCommandSender{}
|
||||
rec := callEndpoint(stub, node.ID)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(stub.listBackendsCalled).To(BeTrue(),
|
||||
"backend nodes must still be queried over NATS")
|
||||
|
||||
var list []messaging.NodeBackendInfo
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &list)).To(Succeed())
|
||||
Expect(list).To(HaveLen(1))
|
||||
Expect(list[0].Name).To(Equal("llama-cpp"))
|
||||
})
|
||||
})
|
||||
@@ -3,6 +3,7 @@ package openresponses
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -246,8 +248,11 @@ func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eval
|
||||
// Create cancellable context for background execution
|
||||
bgCtx, bgCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Store the background response
|
||||
// Store the background response and stamp its owner before the ID
|
||||
// is returned to the client, so later GET/cancel/resume can verify
|
||||
// the caller owns it.
|
||||
store.StoreBackground(responseID, input, queuedResponse, bgCancel, input.Stream)
|
||||
store.SetOwner(responseID, ownerFromContext(c))
|
||||
|
||||
// Start background processing goroutine
|
||||
go func() {
|
||||
@@ -1587,6 +1592,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
if shouldStore {
|
||||
store := GetGlobalStore()
|
||||
store.Store(responseID, input, response)
|
||||
store.SetOwner(responseID, ownerFromContext(c))
|
||||
}
|
||||
|
||||
return c.JSON(200, response)
|
||||
@@ -2322,6 +2328,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
if shouldStore {
|
||||
store := GetGlobalStore()
|
||||
store.Store(responseID, input, responseCompleted)
|
||||
store.SetOwner(responseID, ownerFromContext(c))
|
||||
}
|
||||
|
||||
// Send [DONE]
|
||||
@@ -2966,6 +2973,18 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T
|
||||
return result
|
||||
}
|
||||
|
||||
// ownerFromContext returns the identity (user ID) of the authenticated
|
||||
// caller, or empty string when no authentication was performed (single-key /
|
||||
// no-auth deployments). It is the value stamped on a response at creation and
|
||||
// compared on read/cancel/resume to prevent one caller from accessing
|
||||
// another's response by guessing its ID.
|
||||
func ownerFromContext(c echo.Context) string {
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
return u.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetResponseEndpoint returns a handler for GET /responses/:id
|
||||
// This endpoint is used for polling background responses or resuming streaming
|
||||
// @Summary Get a response by ID
|
||||
@@ -2991,6 +3010,12 @@ func GetResponseEndpoint() func(c echo.Context) error {
|
||||
return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id")
|
||||
}
|
||||
|
||||
// Enforce response ownership. Return 404 (not 403) on mismatch so the
|
||||
// existence of another caller's response is not leaked.
|
||||
if !accessAllowed(stored, ownerFromContext(c)) {
|
||||
return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id")
|
||||
}
|
||||
|
||||
// Check if streaming resume is requested
|
||||
streamParam := c.QueryParam("stream")
|
||||
if streamParam == "true" {
|
||||
@@ -3022,16 +3047,21 @@ func GetResponseEndpoint() func(c echo.Context) error {
|
||||
|
||||
// handleStreamResume handles resuming a streaming response from a specific sequence number
|
||||
func handleStreamResume(c echo.Context, store *ResponseStore, responseID string, stored *StoredResponse, startingAfter int) error {
|
||||
// Fetch buffered events before committing to an SSE response so an
|
||||
// offset-lost gap can be reported as a clean HTTP status rather than a
|
||||
// silently truncated event stream.
|
||||
events, err := store.GetEventsAfter(responseID, startingAfter)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrOffsetLost) {
|
||||
return sendOpenResponsesError(c, 409, "invalid_request_error", fmt.Sprintf("starting_after=%d is older than the oldest retained event; the resume buffer evicted those events and the stream cannot be resumed from that point", startingAfter), "starting_after")
|
||||
}
|
||||
return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to get events: %v", err), "")
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Get buffered events after the starting point
|
||||
events, err := store.GetEventsAfter(responseID, startingAfter)
|
||||
if err != nil {
|
||||
return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to get events: %v", err), "")
|
||||
}
|
||||
|
||||
// Send all buffered events
|
||||
for _, event := range events {
|
||||
fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.EventType, string(event.Data))
|
||||
@@ -3126,6 +3156,17 @@ func CancelResponseEndpoint() func(c echo.Context) error {
|
||||
}
|
||||
|
||||
store := GetGlobalStore()
|
||||
|
||||
// Look up first so ownership can be checked before any mutation.
|
||||
stored, err := store.Get(responseID)
|
||||
if err != nil {
|
||||
return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id")
|
||||
}
|
||||
// Return 404 (not 403) on owner mismatch so existence is not leaked.
|
||||
if !accessAllowed(stored, ownerFromContext(c)) {
|
||||
return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id")
|
||||
}
|
||||
|
||||
response, err := store.Cancel(responseID)
|
||||
if err != nil {
|
||||
return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id")
|
||||
|
||||
@@ -3,6 +3,7 @@ package openresponses
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -11,6 +12,30 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultMaxStreamEvents bounds how many resume-buffer events a single
|
||||
// background response retains. Without a cap, a long-running or abandoned
|
||||
// background generation grows StreamEvents without limit and can exhaust
|
||||
// process memory. When the cap is exceeded the oldest events are evicted
|
||||
// from the front (see AppendEvent). Mirrors llama.cpp's byte-capped slot
|
||||
// ring used for resumable /slots state.
|
||||
defaultMaxStreamEvents = 8192
|
||||
|
||||
// defaultMaxStreamBytes caps the total serialized size of retained
|
||||
// resume-buffer events, evicting oldest-first when exceeded. This guards
|
||||
// against a handful of very large events defeating the count cap. 0
|
||||
// disables the byte cap (count cap still applies).
|
||||
defaultMaxStreamBytes = 64 << 20 // 64 MiB
|
||||
)
|
||||
|
||||
// ErrOffsetLost is returned by GetEventsAfter when the requested
|
||||
// starting_after sequence number is older than the oldest event still
|
||||
// retained in the resume buffer (i.e. the events between the requested
|
||||
// offset and the current watermark were evicted by the cap). Callers should
|
||||
// surface this to clients as a distinct error instead of silently returning
|
||||
// a truncated stream that omits the dropped events.
|
||||
var ErrOffsetLost = errors.New("resume offset lost: requested events were evicted from the buffer")
|
||||
|
||||
// ResponseStore provides thread-safe storage for Open Responses API responses
|
||||
type ResponseStore struct {
|
||||
mu sync.RWMutex
|
||||
@@ -18,6 +43,12 @@ type ResponseStore struct {
|
||||
ttl time.Duration // Time-to-live for stored responses (0 = no expiration)
|
||||
cleanupCtx context.Context
|
||||
cleanupCancel context.CancelFunc
|
||||
|
||||
// maxStreamEvents / maxStreamBytes bound the per-response resume buffer.
|
||||
// Set once at construction from the default constants; tests may lower
|
||||
// them. A value <= 0 disables that particular cap.
|
||||
maxStreamEvents int
|
||||
maxStreamBytes int
|
||||
}
|
||||
|
||||
// StreamedEvent represents a buffered SSE event for streaming resume
|
||||
@@ -35,6 +66,12 @@ type StoredResponse struct {
|
||||
StoredAt time.Time
|
||||
ExpiresAt *time.Time // nil if no expiration
|
||||
|
||||
// Owner is the identity (user ID) that created this response. It is set
|
||||
// once at creation and never mutated, so it can be read without holding
|
||||
// mu. Empty means "no owner" (single-key / no-auth deployments), in which
|
||||
// case ownership checks are skipped for backward compatibility.
|
||||
Owner string
|
||||
|
||||
// Background execution support
|
||||
CancelFunc context.CancelFunc // For cancellation of background tasks
|
||||
StreamEvents []StreamedEvent // Buffered events for streaming resume
|
||||
@@ -42,6 +79,14 @@ type StoredResponse struct {
|
||||
IsBackground bool // Was created with background=true
|
||||
EventsChan chan struct{} // Signals new events for live subscribers
|
||||
mu sync.RWMutex // Protect concurrent access to this response
|
||||
|
||||
// streamBytes tracks the total serialized size of the events currently
|
||||
// retained in StreamEvents, used to enforce the byte cap. droppedThrough
|
||||
// is the highest sequence number evicted from the front of the buffer
|
||||
// (-1 = nothing evicted); it is the watermark GetEventsAfter compares
|
||||
// against to detect a lost resume offset. Both are guarded by mu.
|
||||
streamBytes int
|
||||
droppedThrough int
|
||||
}
|
||||
|
||||
var getGlobalStore = sync.OnceValue(func() *ResponseStore {
|
||||
@@ -81,8 +126,10 @@ func (s *ResponseStore) SetTTL(ttl time.Duration) {
|
||||
// If ttl is 0, responses are stored indefinitely
|
||||
func NewResponseStore(ttl time.Duration) *ResponseStore {
|
||||
store := &ResponseStore{
|
||||
responses: make(map[string]*StoredResponse),
|
||||
ttl: ttl,
|
||||
responses: make(map[string]*StoredResponse),
|
||||
ttl: ttl,
|
||||
maxStreamEvents: defaultMaxStreamEvents,
|
||||
maxStreamBytes: defaultMaxStreamBytes,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine if TTL is set
|
||||
@@ -109,11 +156,12 @@ func (s *ResponseStore) Store(responseID string, request *schema.OpenResponsesRe
|
||||
}
|
||||
|
||||
stored := &StoredResponse{
|
||||
Request: request,
|
||||
Response: response,
|
||||
Items: items,
|
||||
StoredAt: time.Now(),
|
||||
ExpiresAt: nil,
|
||||
Request: request,
|
||||
Response: response,
|
||||
Items: items,
|
||||
StoredAt: time.Now(),
|
||||
ExpiresAt: nil,
|
||||
droppedThrough: -1,
|
||||
}
|
||||
|
||||
// Set expiration if TTL is configured
|
||||
@@ -256,16 +304,17 @@ func (s *ResponseStore) StoreBackground(responseID string, request *schema.OpenR
|
||||
}
|
||||
|
||||
stored := &StoredResponse{
|
||||
Request: request,
|
||||
Response: response,
|
||||
Items: items,
|
||||
StoredAt: time.Now(),
|
||||
ExpiresAt: nil,
|
||||
CancelFunc: cancelFunc,
|
||||
StreamEvents: []StreamedEvent{},
|
||||
StreamEnabled: streamEnabled,
|
||||
IsBackground: true,
|
||||
EventsChan: make(chan struct{}, 100), // Buffered channel for event notifications
|
||||
Request: request,
|
||||
Response: response,
|
||||
Items: items,
|
||||
StoredAt: time.Now(),
|
||||
ExpiresAt: nil,
|
||||
CancelFunc: cancelFunc,
|
||||
StreamEvents: []StreamedEvent{},
|
||||
StreamEnabled: streamEnabled,
|
||||
IsBackground: true,
|
||||
EventsChan: make(chan struct{}, 100), // Buffered channel for event notifications
|
||||
droppedThrough: -1,
|
||||
}
|
||||
|
||||
// Set expiration if TTL is configured
|
||||
@@ -349,6 +398,25 @@ func (s *ResponseStore) AppendEvent(responseID string, event *schema.ORStreamEve
|
||||
EventType: event.Type,
|
||||
Data: data,
|
||||
})
|
||||
stored.streamBytes += len(data)
|
||||
|
||||
// Evict oldest events from the front once either cap is exceeded. The
|
||||
// byte cap never evicts the only remaining event (a single oversized
|
||||
// event is still served once). Each eviction advances droppedThrough so
|
||||
// a later resume below the watermark is reported as ErrOffsetLost rather
|
||||
// than silently skipping the dropped events.
|
||||
for (s.maxStreamEvents > 0 && len(stored.StreamEvents) > s.maxStreamEvents) ||
|
||||
(s.maxStreamBytes > 0 && stored.streamBytes > s.maxStreamBytes && len(stored.StreamEvents) > 1) {
|
||||
evicted := stored.StreamEvents[0]
|
||||
stored.streamBytes -= len(evicted.Data)
|
||||
if evicted.SequenceNumber > stored.droppedThrough {
|
||||
stored.droppedThrough = evicted.SequenceNumber
|
||||
}
|
||||
// Release the evicted payload so it can be GC'd even though the
|
||||
// backing array element is still owned by the slice until reuse.
|
||||
stored.StreamEvents[0].Data = nil
|
||||
stored.StreamEvents = stored.StreamEvents[1:]
|
||||
}
|
||||
stored.mu.Unlock()
|
||||
|
||||
// Notify any subscribers of new event
|
||||
@@ -374,6 +442,14 @@ func (s *ResponseStore) GetEventsAfter(responseID string, startingAfter int) ([]
|
||||
stored.mu.RLock()
|
||||
defer stored.mu.RUnlock()
|
||||
|
||||
// If the requested offset is older than the watermark, the events the
|
||||
// client expects next (those in (startingAfter, droppedThrough]) were
|
||||
// evicted by the cap. Signal the gap rather than returning a stream that
|
||||
// silently skips them.
|
||||
if startingAfter < stored.droppedThrough {
|
||||
return nil, ErrOffsetLost
|
||||
}
|
||||
|
||||
var result []StreamedEvent
|
||||
for _, event := range stored.StreamEvents {
|
||||
if event.SequenceNumber > startingAfter {
|
||||
@@ -447,3 +523,30 @@ func (s *ResponseStore) IsStreamEnabled(responseID string) (bool, error) {
|
||||
|
||||
return stored.StreamEnabled, nil
|
||||
}
|
||||
|
||||
// SetOwner records the identity that owns a stored response. It is called
|
||||
// once, right after the response is stored and before its ID is handed back
|
||||
// to any client, so no lock on the stored response is required. A no-op for
|
||||
// an empty owner or unknown response ID.
|
||||
func (s *ResponseStore) SetOwner(responseID, owner string) {
|
||||
if owner == "" {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
stored, exists := s.responses[responseID]
|
||||
s.mu.RUnlock()
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
stored.Owner = owner
|
||||
}
|
||||
|
||||
// accessAllowed reports whether a caller identified by callerID may read or
|
||||
// mutate the given stored response. An empty owner (single-key / no-auth
|
||||
// deployments) is accessible by anyone, preserving backward compatibility;
|
||||
// otherwise the caller identity must match the recorded owner.
|
||||
func accessAllowed(stored *StoredResponse, callerID string) bool {
|
||||
return stored.Owner == "" || stored.Owner == callerID
|
||||
}
|
||||
|
||||
@@ -585,6 +585,86 @@ var _ = Describe("ResponseStore", func() {
|
||||
Expect(enabled2).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should bound the resume buffer and evict oldest events past the cap", func() {
|
||||
// Lower the caps so the test stays fast; production defaults are
|
||||
// large. Same-package access to the unexported fields is fine.
|
||||
store.maxStreamEvents = 5
|
||||
store.maxStreamBytes = 0 // count cap only for this test
|
||||
|
||||
responseID := "resp_buffer_cap"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
Status: schema.ORStatusInProgress,
|
||||
}
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StoreBackground(responseID, request, response, cancel, true)
|
||||
|
||||
// Append well past the cap.
|
||||
const total = 20
|
||||
for i := range total {
|
||||
err := store.AppendEvent(responseID, &schema.ORStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
SequenceNumber: i,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// (a) Buffer length stays bounded by the cap.
|
||||
Expect(len(stored.StreamEvents)).To(Equal(5))
|
||||
|
||||
// (b) Oldest events were evicted: only the last 5 sequence numbers
|
||||
// remain (15..19).
|
||||
Expect(stored.StreamEvents[0].SequenceNumber).To(Equal(15))
|
||||
Expect(stored.StreamEvents[len(stored.StreamEvents)-1].SequenceNumber).To(Equal(19))
|
||||
|
||||
// Asking for events after the last retained seq still works.
|
||||
retained, err := store.GetEventsAfter(responseID, 14)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(retained).To(HaveLen(5))
|
||||
|
||||
// (c) Asking below the dropped watermark returns ErrOffsetLost.
|
||||
_, err = store.GetEventsAfter(responseID, 0)
|
||||
Expect(err).To(MatchError(ErrOffsetLost))
|
||||
|
||||
_, err = store.GetEventsAfter(responseID, -1)
|
||||
Expect(err).To(MatchError(ErrOffsetLost))
|
||||
})
|
||||
|
||||
It("should record and enforce response ownership", func() {
|
||||
responseID := "resp_owner_test"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
response := &schema.ORResponseResource{ID: responseID, Object: "response", Status: schema.ORStatusCompleted}
|
||||
|
||||
store.Store(responseID, request, response)
|
||||
store.SetOwner(responseID, "userA")
|
||||
|
||||
stored, err := store.Get(responseID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stored.Owner).To(Equal("userA"))
|
||||
|
||||
// Owner matches -> allowed; different identity -> denied.
|
||||
Expect(accessAllowed(stored, "userA")).To(BeTrue())
|
||||
Expect(accessAllowed(stored, "userB")).To(BeFalse())
|
||||
|
||||
// Backward compatibility: a response with no owner is accessible
|
||||
// by any caller (single-key / no-auth deployments).
|
||||
noOwnerID := "resp_no_owner"
|
||||
store.Store(noOwnerID, request, &schema.ORResponseResource{ID: noOwnerID, Object: "response"})
|
||||
noOwner, err := store.Get(noOwnerID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(noOwner.Owner).To(BeEmpty())
|
||||
Expect(accessAllowed(noOwner, "anyone")).To(BeTrue())
|
||||
Expect(accessAllowed(noOwner, "")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should notify subscribers of new events", func() {
|
||||
responseID := "resp_events_chan"
|
||||
request := &schema.OpenResponsesRequest{Model: "test"}
|
||||
|
||||
@@ -88,7 +88,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg))
|
||||
|
||||
// Backend management on workers
|
||||
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))
|
||||
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader, registry))
|
||||
admin.POST("/:id/backends/install", localai.InstallBackendOnNodeEndpoint(unloader, galleryService, opcache, appConfig))
|
||||
admin.POST("/:id/backends/delete", localai.DeleteBackendOnNodeEndpoint(unloader))
|
||||
|
||||
|
||||
@@ -63,6 +63,11 @@ type SmartRouterOptions struct {
|
||||
// The reconciler reads the same instance to autoscale a saturated cache-warm
|
||||
// replica. nil disables recording (the disabled path stays a no-op).
|
||||
Pressure *prefixcache.Pressure
|
||||
// SharedModels asserts that every node mounts the same models directory at
|
||||
// the same path. When true, stageModelFiles skips all uploading and leaves
|
||||
// the absolute model paths untouched so the worker loads them directly from
|
||||
// the shared volume (#10556). See config.DistributedConfig.SharedModels.
|
||||
SharedModels bool
|
||||
}
|
||||
|
||||
// SmartRouter routes inference requests to the best available backend node.
|
||||
@@ -93,6 +98,9 @@ type SmartRouter struct {
|
||||
// per-request routing doesn't stall behind a busy backend's serialized
|
||||
// HealthCheck/Predict. See probe_cache.go for the rationale.
|
||||
probeCache *probeCache
|
||||
// sharedModels skips file staging when all nodes mount the same models
|
||||
// directory at the same path (see SmartRouterOptions.SharedModels).
|
||||
sharedModels bool
|
||||
}
|
||||
|
||||
// probeCacheTTL is how long a successful gRPC HealthCheck on a backend is
|
||||
@@ -122,6 +130,7 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
prefixProvider: opts.PrefixProvider,
|
||||
prefixConfig: opts.PrefixConfig,
|
||||
pressure: opts.Pressure,
|
||||
sharedModels: opts.SharedModels,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -947,6 +956,19 @@ func (r *SmartRouter) buildClientForAddr(node *BackendNode, addr string, paralle
|
||||
// simply remove the {ModelsPath}/{trackingKey}/ directory.
|
||||
func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, opts *pb.ModelOptions, trackingKey string) (*pb.ModelOptions, error) {
|
||||
opts = proto.Clone(opts).(*pb.ModelOptions)
|
||||
|
||||
// Shared-models mode: every node mounts the same models directory at the
|
||||
// same path, so the frontend's absolute model paths are already valid on the
|
||||
// worker. Staging would only re-upload files that already exist on the shared
|
||||
// volume (under a tracking-key subdir the probe never reuses), re-downloading
|
||||
// the model on every load (#10556). Return the clone untouched: no upload, no
|
||||
// path rewrite, no staging tracker.
|
||||
if r.sharedModels {
|
||||
xlog.Info("Skipping model file staging: shared-models mode is on (LOCALAI_DISTRIBUTED_SHARED_MODELS); worker loads directly from the shared volume",
|
||||
"node", node.Name, "modelFile", opts.ModelFile, "trackingKey", trackingKey)
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
xlog.Info("Staging model files for remote node", "node", node.Name, "modelFile", opts.ModelFile, "trackingKey", trackingKey)
|
||||
|
||||
// Derive the frontend models directory from ModelFile and Model.
|
||||
|
||||
85
core/services/nodes/router_sharedmodels_test.go
Normal file
85
core/services/nodes/router_sharedmodels_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// These tests cover shared-models mode (LOCALAI_DISTRIBUTED_SHARED_MODELS): when
|
||||
// every node mounts the same models directory at the same path, the router must
|
||||
// NOT stage model files to workers. The canonical absolute path is already valid
|
||||
// on the worker, so staging would only re-download what is already present
|
||||
// (#10556).
|
||||
var _ = Describe("stageModelFiles shared-models mode", func() {
|
||||
var (
|
||||
stager *fakeFileStager
|
||||
node *BackendNode
|
||||
tmp string
|
||||
gguf string
|
||||
modelID = "ornith-1.0-35b"
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stager = &fakeFileStager{}
|
||||
node = &BackendNode{ID: "node-1", Name: "node-1", Address: "10.0.0.1:50051"}
|
||||
tmp = GinkgoT().TempDir()
|
||||
|
||||
modelDir := filepath.Join(tmp, "models", "llama-cpp", "models")
|
||||
Expect(os.MkdirAll(modelDir, 0o755)).To(Succeed())
|
||||
gguf = filepath.Join(modelDir, "ornith.gguf")
|
||||
Expect(os.WriteFile(gguf, []byte("weights"), 0o644)).To(Succeed())
|
||||
})
|
||||
|
||||
It("does not stage and keeps the canonical absolute ModelFile when shared-models is enabled", func() {
|
||||
router := &SmartRouter{
|
||||
fileStager: stager,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
sharedModels: true,
|
||||
}
|
||||
|
||||
opts := &pb.ModelOptions{
|
||||
Model: "llama-cpp/models/ornith.gguf",
|
||||
ModelFile: gguf,
|
||||
}
|
||||
|
||||
staged, err := router.stageModelFiles(context.Background(), node, opts, modelID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// The file stager must never be touched: no upload, no re-download.
|
||||
Expect(stager.ensureCalls).To(BeEmpty())
|
||||
// The worker loads directly from the shared volume, so the path is unchanged.
|
||||
Expect(staged.ModelFile).To(Equal(gguf))
|
||||
})
|
||||
|
||||
It("stages files (existing behavior) when shared-models is disabled", func() {
|
||||
router := &SmartRouter{
|
||||
fileStager: stager,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
sharedModels: false,
|
||||
}
|
||||
|
||||
opts := &pb.ModelOptions{
|
||||
Model: "llama-cpp/models/ornith.gguf",
|
||||
ModelFile: gguf,
|
||||
}
|
||||
|
||||
staged, err := router.stageModelFiles(context.Background(), node, opts, modelID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Default mode uploads the model file to the worker.
|
||||
Expect(stager.ensureCalls).ToNot(BeEmpty())
|
||||
stagedLocals := make([]string, 0, len(stager.ensureCalls))
|
||||
for _, c := range stager.ensureCalls {
|
||||
stagedLocals = append(stagedLocals, c.localPath)
|
||||
}
|
||||
Expect(stagedLocals).To(ContainElement(gguf))
|
||||
// ModelFile is rewritten to the remote (tracking-key namespaced) path.
|
||||
Expect(staged.ModelFile).ToNot(Equal(gguf))
|
||||
})
|
||||
})
|
||||
@@ -57,6 +57,11 @@ services:
|
||||
LOCALAI_AGENT_POOL_VECTOR_ENGINE: "postgres"
|
||||
LOCALAI_AGENT_POOL_DATABASE_URL: "postgresql://localai:localai@postgres:5432/localai?sslmode=disable"
|
||||
LOCALAI_REGISTRATION_TOKEN: "changeme" # Change this in production!
|
||||
# Shared-models mode (optional): set when every node mounts the SAME
|
||||
# models directory at the SAME path (see "Shared Volume Mode" below).
|
||||
# The router then skips gRPC file staging and workers load models
|
||||
# directly from the shared volume instead of re-downloading them.
|
||||
# LOCALAI_DISTRIBUTED_SHARED_MODELS: "true"
|
||||
# Auth (required for distributed mode — must use PostgreSQL)
|
||||
LOCALAI_AUTH: "true"
|
||||
LOCALAI_AUTH_DATABASE_URL: "postgresql://localai:localai@postgres:5432/localai?sslmode=disable"
|
||||
@@ -157,8 +162,11 @@ services:
|
||||
# Then add to the volumes section:
|
||||
# shared_models:
|
||||
#
|
||||
# With shared volumes, model files are already available on the backend —
|
||||
# gRPC file staging becomes a no-op (paths match).
|
||||
# With shared volumes the model files are already present on every worker at
|
||||
# the same path. Set LOCALAI_DISTRIBUTED_SHARED_MODELS=true on the frontend
|
||||
# (see its environment above) so the router skips gRPC file staging and the
|
||||
# worker loads the model directly from the shared path instead of
|
||||
# re-downloading it into a per-model subdirectory.
|
||||
|
||||
# --- Adding More Workers ---
|
||||
# Copy the worker-1 service above and change:
|
||||
|
||||
@@ -67,6 +67,7 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
|
||||
| `--registration-require-auth` | `LOCALAI_REGISTRATION_REQUIRE_AUTH` | `false` | Fail startup when distributed mode is enabled but the registration token is empty (node endpoints and worker file-transfer would otherwise be unauthenticated) |
|
||||
| `--distributed-require-auth` | `LOCALAI_DISTRIBUTED_REQUIRE_AUTH` | `false` | **Umbrella switch.** Implies both `--nats-require-auth` and `--registration-require-auth` — one knob to lock down the NATS bus *and* the registration/file-transfer layer. Set this in production instead of the two granular flags. |
|
||||
| `--auto-approve-nodes` | `LOCALAI_AUTO_APPROVE_NODES` | `false` | Auto-approve new worker nodes (skip admin approval) |
|
||||
| `--distributed-shared-models` | `LOCALAI_DISTRIBUTED_SHARED_MODELS` | `false` | Assert that every node mounts the **same** models directory at the **same** path (a shared volume). When `true`, the router skips file staging entirely and workers load models directly from the shared path instead of re-downloading them. See [Shared models directory](#shared-models-directory). |
|
||||
| `--auth` | `LOCALAI_AUTH` | `false` | **Must be `true`** for distributed mode |
|
||||
| `--auth-database-url` | `LOCALAI_AUTH_DATABASE_URL` | *(required)* | PostgreSQL connection URL |
|
||||
| `--backend-install-timeout` | `LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT` | `15m` | How long the frontend waits for a worker to acknowledge a backend install before considering the request stalled. Raise it when workers pull large backend images over slow links. If a worker takes longer than this, the operation shows as "still installing in background" in the admin UI and clears once the worker finishes. |
|
||||
@@ -133,6 +134,14 @@ When S3 is not configured, model files are transferred directly from the fronten
|
||||
|
||||
For high-throughput or very large model files, S3 can be more efficient since it avoids streaming through the frontend.
|
||||
|
||||
### Shared models directory
|
||||
|
||||
If every node (frontend and workers) mounts the **same** models directory at the **same** path - for example a shared volume or network filesystem, as shown in the "Shared Volume Mode" section of `docker-compose.distributed.yaml` - the model files are already present on each worker at their canonical path. In that case staging is wasted work: it copies files that already exist into a per-model subdirectory the worker then loads from, which shows up as a re-download of a model you already have.
|
||||
|
||||
Set `LOCALAI_DISTRIBUTED_SHARED_MODELS=true` (or `--distributed-shared-models`) on the frontend to skip staging entirely. The router then leaves the model's absolute paths untouched and the worker loads them directly from the shared volume.
|
||||
|
||||
This flag is a contract you assert: all nodes must mount identical paths. Leave it off (the default) when workers have independent models directories - the frontend stages files to them over HTTP (or S3) as described above.
|
||||
|
||||
{{% notice warning %}}
|
||||
The worker HTTP file transfer server is authenticated by `LOCALAI_REGISTRATION_TOKEN`. If the token is **empty**, the server **fails open** — anyone who can reach the port gets read/write access to the worker's models/staging/data directories (a remote model-poisoning / exfiltration vector). The worker logs a loud warning at startup in this case. Always set `LOCALAI_REGISTRATION_TOKEN` in distributed mode, and set `LOCALAI_DISTRIBUTED_REQUIRE_AUTH=true` (frontend **and** workers) to make a missing token *or* missing NATS credentials a hard startup error rather than a silent fail-open. Firewall the file-transfer port (gRPC base − 1) so only the frontend can reach it.
|
||||
{{% /notice %}}
|
||||
|
||||
Reference in New Issue
Block a user