mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-27 09:57:14 -04:00
Compare commits
5 Commits
feat/syncs
...
v4.5.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d11b202dd2 | ||
|
|
e95018ef70 | ||
|
|
0258f8af55 | ||
|
|
14b29ebf4e | ||
|
|
f0d0bff232 |
16
.github/workflows/test.yml
vendored
16
.github/workflows/test.yml
vendored
@@ -121,3 +121,19 @@ jobs:
|
|||||||
detached: true
|
detached: true
|
||||||
connect-timeout-seconds: 180
|
connect-timeout-seconds: 180
|
||||||
limit-access-to-actor: true
|
limit-access-to-actor: true
|
||||||
|
|
||||||
|
# Fast standalone unit tests for the backends' pure C++ helpers - currently the
|
||||||
|
# llama-cpp message reconstruction (backend/cpp/llama-cpp/message_content.h),
|
||||||
|
# which guards the OpenAI chat content normalization (mudler/LocalAI#10524,
|
||||||
|
# #7324, #7528). The runner discovers every *_test.cpp under backend/cpp/, so
|
||||||
|
# new pure-C++ unit tests are picked up with no CI changes. These need only the
|
||||||
|
# C++ stdlib + nlohmann/json, so they run on every PR without the full
|
||||||
|
# llama.cpp + gRPC backend build. (The same suite is also wired as an opt-in
|
||||||
|
# CMake/ctest target, -DLLAMA_GRPC_BUILD_TESTS=ON, for in-backend-build runs.)
|
||||||
|
tests-backend-cpp:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v7
|
||||||
|
- name: Run backend C++ unit tests
|
||||||
|
run: make test-backend-cpp
|
||||||
|
|||||||
9
Makefile
9
Makefile
@@ -103,7 +103,7 @@ COVERAGE_E2E_LABELS?=!real-models
|
|||||||
COVERAGE_EXCLUDE_RE?=grpc/proto/.*[.]pb[.]go
|
COVERAGE_EXCLUDE_RE?=grpc/proto/.*[.]pb[.]go
|
||||||
|
|
||||||
|
|
||||||
.PHONY: all test test-coverage test-coverage-baseline test-coverage-check test-ui test-ui-coverage-baseline test-ui-coverage-check install-hooks build vendor lint lint-all
|
.PHONY: all test test-coverage test-coverage-baseline test-coverage-check test-backend-cpp test-ui test-ui-coverage-baseline test-ui-coverage-check install-hooks build vendor lint lint-all
|
||||||
|
|
||||||
all: help
|
all: help
|
||||||
|
|
||||||
@@ -201,6 +201,13 @@ test: prepare-test
|
|||||||
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
|
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
|
## Compiles and runs the standalone C++ unit tests for the backends (pure
|
||||||
|
## helpers that depend only on the stdlib + nlohmann/json, no full backend
|
||||||
|
## build). Discovers every *_test.cpp under backend/cpp/ - see
|
||||||
|
## backend/cpp/run-unit-tests.sh. Set NLOHMANN_INCLUDE to skip the header fetch.
|
||||||
|
test-backend-cpp:
|
||||||
|
bash backend/cpp/run-unit-tests.sh
|
||||||
|
|
||||||
## Runs the core suite ($(TEST_PATHS)) with statement-coverage instrumentation
|
## Runs the core suite ($(TEST_PATHS)) with statement-coverage instrumentation
|
||||||
## and writes a merged profile to $(COVERAGE_PROFILE). Deliberately omits
|
## and writes a merged profile to $(COVERAGE_PROFILE). Deliberately omits
|
||||||
## --fail-fast so a single failure doesn't truncate the coverage number, and
|
## --fail-fast so a single failure doesn't truncate the coverage number, and
|
||||||
|
|||||||
@@ -87,3 +87,18 @@ target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
|||||||
if(TARGET BUILD_INFO)
|
if(TARGET BUILD_INFO)
|
||||||
add_dependencies(${TARGET} BUILD_INFO)
|
add_dependencies(${TARGET} BUILD_INFO)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Unit test for the message-content normalization helper (message_content.h).
|
||||||
|
# Off by default so the normal backend build is untouched; enable with
|
||||||
|
# -DLLAMA_GRPC_BUILD_TESTS=ON and run via ctest. It reuses llama.cpp's vendored
|
||||||
|
# <nlohmann/json.hpp> (propagated by the common helpers library) so it has no
|
||||||
|
# extra dependency beyond what the backend already builds against.
|
||||||
|
option(LLAMA_GRPC_BUILD_TESTS "Build grpc-server unit tests" OFF)
|
||||||
|
if(LLAMA_GRPC_BUILD_TESTS)
|
||||||
|
enable_testing()
|
||||||
|
add_executable(message_content_test message_content_test.cpp message_content.h)
|
||||||
|
target_include_directories(message_content_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
target_link_libraries(message_content_test PRIVATE ${_LLAMA_COMMON_TARGET})
|
||||||
|
target_compile_features(message_content_test PRIVATE cxx_std_17)
|
||||||
|
add_test(NAME message_content_test COMMAND message_content_test)
|
||||||
|
endif()
|
||||||
|
|||||||
@@ -39,6 +39,7 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "arg.h"
|
#include "arg.h"
|
||||||
#include "chat-auto-parser.h"
|
#include "chat-auto-parser.h"
|
||||||
|
#include "message_content.h"
|
||||||
#include <getopt.h>
|
#include <getopt.h>
|
||||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
#include <grpcpp/grpcpp.h>
|
#include <grpcpp/grpcpp.h>
|
||||||
@@ -1616,242 +1617,20 @@ public:
|
|||||||
|
|
||||||
for (int i = 0; i < request->messages_size(); i++) {
|
for (int i = 0; i < request->messages_size(); i++) {
|
||||||
const auto& msg = request->messages(i);
|
const auto& msg = request->messages(i);
|
||||||
json msg_json;
|
llama_grpc::ReconstructedMessageInput rin;
|
||||||
msg_json["role"] = msg.role();
|
rin.role = msg.role();
|
||||||
|
rin.content = msg.content();
|
||||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
rin.name = msg.name();
|
||||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
rin.tool_call_id = msg.tool_call_id();
|
||||||
|
rin.reasoning_content = msg.reasoning_content();
|
||||||
// Handle content - can be string, null, or array
|
rin.tool_calls = msg.tool_calls();
|
||||||
// For multimodal content, we'll embed images/audio from separate fields
|
rin.is_last_user_msg = (i == last_user_msg_idx);
|
||||||
if (!msg.content().empty()) {
|
if (rin.is_last_user_msg) {
|
||||||
// Try to parse content as JSON to see if it's already an array
|
for (int j = 0; j < request->images_size(); j++) rin.images.push_back(request->images(j));
|
||||||
json content_val;
|
for (int j = 0; j < request->audios_size(); j++) rin.audios.push_back(request->audios(j));
|
||||||
try {
|
for (int j = 0; j < request->videos_size(); j++) rin.videos.push_back(request->videos(j));
|
||||||
content_val = json::parse(msg.content());
|
|
||||||
// Handle null values - convert to empty string to avoid template errors
|
|
||||||
if (content_val.is_null()) {
|
|
||||||
content_val = "";
|
|
||||||
}
|
|
||||||
} catch (const json::parse_error&) {
|
|
||||||
// Not JSON, treat as plain string
|
|
||||||
content_val = msg.content();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is an object (e.g., from tool call failures), convert to string
|
|
||||||
if (content_val.is_object()) {
|
|
||||||
content_val = content_val.dump();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is a string and this is the last user message with images/audio, combine them
|
|
||||||
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
|
|
||||||
json content_array = json::array();
|
|
||||||
// Add text first
|
|
||||||
content_array.push_back({{"type", "text"}, {"text", content_val.get<std::string>()}});
|
|
||||||
// Add images
|
|
||||||
if (request->images_size() > 0) {
|
|
||||||
for (int j = 0; j < request->images_size(); j++) {
|
|
||||||
json image_chunk;
|
|
||||||
image_chunk["type"] = "image_url";
|
|
||||||
json image_url;
|
|
||||||
image_url["url"] = "data:image/jpeg;base64," + request->images(j);
|
|
||||||
image_chunk["image_url"] = image_url;
|
|
||||||
content_array.push_back(image_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Add audios
|
|
||||||
if (request->audios_size() > 0) {
|
|
||||||
for (int j = 0; j < request->audios_size(); j++) {
|
|
||||||
json audio_chunk;
|
|
||||||
audio_chunk["type"] = "input_audio";
|
|
||||||
json input_audio;
|
|
||||||
input_audio["data"] = request->audios(j);
|
|
||||||
input_audio["format"] = "wav"; // default, could be made configurable
|
|
||||||
audio_chunk["input_audio"] = input_audio;
|
|
||||||
content_array.push_back(audio_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (request->videos_size() > 0) {
|
|
||||||
for (int j = 0; j < request->videos_size(); j++) {
|
|
||||||
json video_chunk;
|
|
||||||
video_chunk["type"] = "input_video";
|
|
||||||
json input_video;
|
|
||||||
input_video["data"] = request->videos(j);
|
|
||||||
video_chunk["input_video"] = input_video;
|
|
||||||
content_array.push_back(video_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg_json["content"] = content_array;
|
|
||||||
} else {
|
|
||||||
// Use content as-is (already array or not last user message)
|
|
||||||
// Ensure null values are converted to empty string
|
|
||||||
if (content_val.is_null()) {
|
|
||||||
msg_json["content"] = "";
|
|
||||||
} else {
|
|
||||||
msg_json["content"] = content_val;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (is_last_user_msg && has_images_or_audio) {
|
|
||||||
// If no content but this is the last user message with images/audio, create content array
|
|
||||||
json content_array = json::array();
|
|
||||||
if (request->images_size() > 0) {
|
|
||||||
for (int j = 0; j < request->images_size(); j++) {
|
|
||||||
json image_chunk;
|
|
||||||
image_chunk["type"] = "image_url";
|
|
||||||
json image_url;
|
|
||||||
image_url["url"] = "data:image/jpeg;base64," + request->images(j);
|
|
||||||
image_chunk["image_url"] = image_url;
|
|
||||||
content_array.push_back(image_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (request->audios_size() > 0) {
|
|
||||||
for (int j = 0; j < request->audios_size(); j++) {
|
|
||||||
json audio_chunk;
|
|
||||||
audio_chunk["type"] = "input_audio";
|
|
||||||
json input_audio;
|
|
||||||
input_audio["data"] = request->audios(j);
|
|
||||||
input_audio["format"] = "wav"; // default, could be made configurable
|
|
||||||
audio_chunk["input_audio"] = input_audio;
|
|
||||||
content_array.push_back(audio_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (request->videos_size() > 0) {
|
|
||||||
for (int j = 0; j < request->videos_size(); j++) {
|
|
||||||
json video_chunk;
|
|
||||||
video_chunk["type"] = "input_video";
|
|
||||||
json input_video;
|
|
||||||
input_video["data"] = request->videos(j);
|
|
||||||
video_chunk["input_video"] = input_video;
|
|
||||||
content_array.push_back(video_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg_json["content"] = content_array;
|
|
||||||
} else if (msg.role() == "tool") {
|
|
||||||
// Tool role messages must have content field set, even if empty
|
|
||||||
// Jinja templates expect content to be a string, not null or object
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0);
|
|
||||||
if (msg.content().empty()) {
|
|
||||||
msg_json["content"] = "";
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): empty content, set to empty string\n", i);
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): content exists: %s\n",
|
|
||||||
i, msg.content().substr(0, std::min<size_t>(200, msg.content().size())).c_str());
|
|
||||||
// Content exists, parse and ensure it's a string
|
|
||||||
json content_val;
|
|
||||||
try {
|
|
||||||
content_val = json::parse(msg.content());
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): parsed JSON, type=%s\n",
|
|
||||||
i, content_val.is_null() ? "null" :
|
|
||||||
content_val.is_object() ? "object" :
|
|
||||||
content_val.is_string() ? "string" :
|
|
||||||
content_val.is_array() ? "array" : "other");
|
|
||||||
// Handle null values - Jinja templates expect content to be a string, not null
|
|
||||||
if (content_val.is_null()) {
|
|
||||||
msg_json["content"] = "";
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): null content, converted to empty string\n", i);
|
|
||||||
} else if (content_val.is_object()) {
|
|
||||||
// If content is an object (e.g., from tool call failures/errors), convert to string
|
|
||||||
msg_json["content"] = content_val.dump();
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): object content, converted to string: %s\n",
|
|
||||||
i, content_val.dump().substr(0, std::min<size_t>(200, content_val.dump().size())).c_str());
|
|
||||||
} else if (content_val.is_string()) {
|
|
||||||
msg_json["content"] = content_val.get<std::string>();
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): string content, using as-is\n", i);
|
|
||||||
} else {
|
|
||||||
// For arrays or other types, convert to string
|
|
||||||
msg_json["content"] = content_val.dump();
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): %s content, converted to string\n",
|
|
||||||
i, content_val.is_array() ? "array" : "other type");
|
|
||||||
}
|
|
||||||
} catch (const json::parse_error&) {
|
|
||||||
// Not JSON, treat as plain string
|
|
||||||
msg_json["content"] = msg.content();
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): not JSON, using as string\n", i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Ensure all messages have content set (fallback for any unhandled cases)
|
|
||||||
// Jinja templates expect content to be present, default to empty string if not set
|
|
||||||
if (!msg_json.contains("content")) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (role=%s): no content field, adding empty string\n",
|
|
||||||
i, msg.role().c_str());
|
|
||||||
msg_json["content"] = "";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
messages_json.push_back(llama_grpc::build_reconstructed_message(rin));
|
||||||
// Add optional fields for OpenAI-compatible message format
|
|
||||||
if (!msg.name().empty()) {
|
|
||||||
msg_json["name"] = msg.name();
|
|
||||||
}
|
|
||||||
if (!msg.tool_call_id().empty()) {
|
|
||||||
msg_json["tool_call_id"] = msg.tool_call_id();
|
|
||||||
}
|
|
||||||
if (!msg.reasoning_content().empty()) {
|
|
||||||
msg_json["reasoning_content"] = msg.reasoning_content();
|
|
||||||
}
|
|
||||||
if (!msg.tool_calls().empty()) {
|
|
||||||
// Parse tool_calls JSON string and add to message
|
|
||||||
try {
|
|
||||||
json tool_calls = json::parse(msg.tool_calls());
|
|
||||||
msg_json["tool_calls"] = tool_calls;
|
|
||||||
SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str());
|
|
||||||
// IMPORTANT: If message has tool_calls but content is empty or not set,
|
|
||||||
// set content to space " " instead of empty string "", because llama.cpp's
|
|
||||||
// common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312),
|
|
||||||
// which causes template errors when accessing message.content[:tool_start_length]
|
|
||||||
if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get<std::string>().empty())) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d has tool_calls but empty content, setting to space\n", i);
|
|
||||||
msg_json["content"] = " ";
|
|
||||||
}
|
|
||||||
// Log each tool call with name and arguments
|
|
||||||
if (tool_calls.is_array()) {
|
|
||||||
for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) {
|
|
||||||
const auto& tc = tool_calls[tc_idx];
|
|
||||||
std::string tool_name = "unknown";
|
|
||||||
std::string tool_args = "{}";
|
|
||||||
if (tc.contains("function")) {
|
|
||||||
const auto& func = tc["function"];
|
|
||||||
if (func.contains("name")) {
|
|
||||||
tool_name = func["name"].get<std::string>();
|
|
||||||
}
|
|
||||||
if (func.contains("arguments")) {
|
|
||||||
tool_args = func["arguments"].is_string() ?
|
|
||||||
func["arguments"].get<std::string>() :
|
|
||||||
func["arguments"].dump();
|
|
||||||
}
|
|
||||||
} else if (tc.contains("name")) {
|
|
||||||
tool_name = tc["name"].get<std::string>();
|
|
||||||
if (tc.contains("arguments")) {
|
|
||||||
tool_args = tc["arguments"].is_string() ?
|
|
||||||
tc["arguments"].get<std::string>() :
|
|
||||||
tc["arguments"].dump();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d, tool_call %zu: name=%s, arguments=%s\n",
|
|
||||||
i, tc_idx, tool_name.c_str(), tool_args.c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (const json::parse_error& e) {
|
|
||||||
SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Log final content state before adding to array
|
|
||||||
if (msg_json.contains("content")) {
|
|
||||||
if (msg_json["content"].is_null()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i);
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: content type=%s, has_value=%d\n",
|
|
||||||
i, msg_json["content"].is_string() ? "string" :
|
|
||||||
msg_json["content"].is_array() ? "array" :
|
|
||||||
msg_json["content"].is_object() ? "object" : "other",
|
|
||||||
msg_json["content"].is_null() ? 0 : 1);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i);
|
|
||||||
}
|
|
||||||
|
|
||||||
messages_json.push_back(msg_json);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final safety check: Ensure no message has null content (Jinja templates require strings)
|
// Final safety check: Ensure no message has null content (Jinja templates require strings)
|
||||||
@@ -2072,36 +1851,7 @@ public:
|
|||||||
if (body_json.contains("messages") && body_json["messages"].is_array()) {
|
if (body_json.contains("messages") && body_json["messages"].is_array()) {
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size());
|
SRV_INF("[CONTENT DEBUG] PredictStream: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size());
|
||||||
for (size_t idx = 0; idx < body_json["messages"].size(); idx++) {
|
for (size_t idx = 0; idx < body_json["messages"].size(); idx++) {
|
||||||
auto& msg = body_json["messages"][idx];
|
llama_grpc::normalize_template_message(body_json["messages"][idx]);
|
||||||
std::string role_str = msg.contains("role") ? msg["role"].get<std::string>() : "unknown";
|
|
||||||
if (msg.contains("content")) {
|
|
||||||
if (msg["content"].is_null()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str());
|
|
||||||
msg["content"] = ""; // Fix null content
|
|
||||||
} else if (role_str == "tool" && msg["content"].is_array()) {
|
|
||||||
// Tool messages must have string content, not array
|
|
||||||
// oaicompat_chat_params_parse expects tool messages to have string content
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx);
|
|
||||||
msg["content"] = msg["content"].dump();
|
|
||||||
} else if (!msg["content"].is_string() && !msg["content"].is_array()) {
|
|
||||||
// If content is object or other non-string type, convert to string for templates
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str());
|
|
||||||
if (msg["content"].is_object()) {
|
|
||||||
msg["content"] = msg["content"].dump();
|
|
||||||
} else {
|
|
||||||
msg["content"] = "";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n",
|
|
||||||
idx, role_str.c_str(),
|
|
||||||
msg["content"].is_string() ? "string" :
|
|
||||||
msg["content"].is_array() ? "array" :
|
|
||||||
msg["content"].is_object() ? "object" : "other");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str());
|
|
||||||
msg["content"] = ""; // Add missing content
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2433,264 +2183,20 @@ public:
|
|||||||
SRV_INF("[CONTENT DEBUG] Predict: Processing %d messages\n", request->messages_size());
|
SRV_INF("[CONTENT DEBUG] Predict: Processing %d messages\n", request->messages_size());
|
||||||
for (int i = 0; i < request->messages_size(); i++) {
|
for (int i = 0; i < request->messages_size(); i++) {
|
||||||
const auto& msg = request->messages(i);
|
const auto& msg = request->messages(i);
|
||||||
json msg_json;
|
llama_grpc::ReconstructedMessageInput rin;
|
||||||
msg_json["role"] = msg.role();
|
rin.role = msg.role();
|
||||||
|
rin.content = msg.content();
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d: role=%s, content_empty=%d, content_length=%zu\n",
|
rin.name = msg.name();
|
||||||
i, msg.role().c_str(), msg.content().empty() ? 1 : 0, msg.content().size());
|
rin.tool_call_id = msg.tool_call_id();
|
||||||
if (!msg.content().empty()) {
|
rin.reasoning_content = msg.reasoning_content();
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d content (first 200 chars): %s\n",
|
rin.tool_calls = msg.tool_calls();
|
||||||
i, msg.content().substr(0, std::min<size_t>(200, msg.content().size())).c_str());
|
rin.is_last_user_msg = (i == last_user_msg_idx);
|
||||||
|
if (rin.is_last_user_msg) {
|
||||||
|
for (int j = 0; j < request->images_size(); j++) rin.images.push_back(request->images(j));
|
||||||
|
for (int j = 0; j < request->audios_size(); j++) rin.audios.push_back(request->audios(j));
|
||||||
|
for (int j = 0; j < request->videos_size(); j++) rin.videos.push_back(request->videos(j));
|
||||||
}
|
}
|
||||||
|
messages_json.push_back(llama_grpc::build_reconstructed_message(rin));
|
||||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
|
||||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
|
||||||
|
|
||||||
// Handle content - can be string, null, or array
|
|
||||||
// For multimodal content, we'll embed images/audio from separate fields
|
|
||||||
if (!msg.content().empty()) {
|
|
||||||
// Try to parse content as JSON to see if it's already an array
|
|
||||||
json content_val;
|
|
||||||
try {
|
|
||||||
content_val = json::parse(msg.content());
|
|
||||||
// Handle null values - convert to empty string to avoid template errors
|
|
||||||
if (content_val.is_null()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d parsed JSON is null, converting to empty string\n", i);
|
|
||||||
content_val = "";
|
|
||||||
}
|
|
||||||
} catch (const json::parse_error&) {
|
|
||||||
// Not JSON, treat as plain string
|
|
||||||
content_val = msg.content();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is an object (e.g., from tool call failures), convert to string
|
|
||||||
if (content_val.is_object()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d content is object, converting to string\n", i);
|
|
||||||
content_val = content_val.dump();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is a string and this is the last user message with images/audio, combine them
|
|
||||||
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
|
|
||||||
json content_array = json::array();
|
|
||||||
// Add text first
|
|
||||||
content_array.push_back({{"type", "text"}, {"text", content_val.get<std::string>()}});
|
|
||||||
// Add images
|
|
||||||
if (request->images_size() > 0) {
|
|
||||||
for (int j = 0; j < request->images_size(); j++) {
|
|
||||||
json image_chunk;
|
|
||||||
image_chunk["type"] = "image_url";
|
|
||||||
json image_url;
|
|
||||||
image_url["url"] = "data:image/jpeg;base64," + request->images(j);
|
|
||||||
image_chunk["image_url"] = image_url;
|
|
||||||
content_array.push_back(image_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Add audios
|
|
||||||
if (request->audios_size() > 0) {
|
|
||||||
for (int j = 0; j < request->audios_size(); j++) {
|
|
||||||
json audio_chunk;
|
|
||||||
audio_chunk["type"] = "input_audio";
|
|
||||||
json input_audio;
|
|
||||||
input_audio["data"] = request->audios(j);
|
|
||||||
input_audio["format"] = "wav"; // default, could be made configurable
|
|
||||||
audio_chunk["input_audio"] = input_audio;
|
|
||||||
content_array.push_back(audio_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (request->videos_size() > 0) {
|
|
||||||
for (int j = 0; j < request->videos_size(); j++) {
|
|
||||||
json video_chunk;
|
|
||||||
video_chunk["type"] = "input_video";
|
|
||||||
json input_video;
|
|
||||||
input_video["data"] = request->videos(j);
|
|
||||||
video_chunk["input_video"] = input_video;
|
|
||||||
content_array.push_back(video_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg_json["content"] = content_array;
|
|
||||||
} else {
|
|
||||||
// Use content as-is (already array or not last user message)
|
|
||||||
// Ensure null values are converted to empty string
|
|
||||||
if (content_val.is_null()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d content_val was null, setting to empty string\n", i);
|
|
||||||
msg_json["content"] = "";
|
|
||||||
} else {
|
|
||||||
msg_json["content"] = content_val;
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d content set, type=%s\n",
|
|
||||||
i, content_val.is_string() ? "string" :
|
|
||||||
content_val.is_array() ? "array" :
|
|
||||||
content_val.is_object() ? "object" : "other");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (is_last_user_msg && has_images_or_audio) {
|
|
||||||
// If no content but this is the last user message with images/audio, create content array
|
|
||||||
json content_array = json::array();
|
|
||||||
if (request->images_size() > 0) {
|
|
||||||
for (int j = 0; j < request->images_size(); j++) {
|
|
||||||
json image_chunk;
|
|
||||||
image_chunk["type"] = "image_url";
|
|
||||||
json image_url;
|
|
||||||
image_url["url"] = "data:image/jpeg;base64," + request->images(j);
|
|
||||||
image_chunk["image_url"] = image_url;
|
|
||||||
content_array.push_back(image_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (request->audios_size() > 0) {
|
|
||||||
for (int j = 0; j < request->audios_size(); j++) {
|
|
||||||
json audio_chunk;
|
|
||||||
audio_chunk["type"] = "input_audio";
|
|
||||||
json input_audio;
|
|
||||||
input_audio["data"] = request->audios(j);
|
|
||||||
input_audio["format"] = "wav"; // default, could be made configurable
|
|
||||||
audio_chunk["input_audio"] = input_audio;
|
|
||||||
content_array.push_back(audio_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (request->videos_size() > 0) {
|
|
||||||
for (int j = 0; j < request->videos_size(); j++) {
|
|
||||||
json video_chunk;
|
|
||||||
video_chunk["type"] = "input_video";
|
|
||||||
json input_video;
|
|
||||||
input_video["data"] = request->videos(j);
|
|
||||||
video_chunk["input_video"] = input_video;
|
|
||||||
content_array.push_back(video_chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg_json["content"] = content_array;
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
|
|
||||||
} else if (!msg.tool_calls().empty()) {
|
|
||||||
// Tool call messages may have null content, but templates expect string
|
|
||||||
// IMPORTANT: Set to space " " instead of empty string "", because llama.cpp's
|
|
||||||
// common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312),
|
|
||||||
// which causes template errors when accessing message.content[:tool_start_length]
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls, setting content to space (not empty string)\n", i);
|
|
||||||
msg_json["content"] = " ";
|
|
||||||
} else if (msg.role() == "tool") {
|
|
||||||
// Tool role messages must have content field set, even if empty
|
|
||||||
// Jinja templates expect content to be a string, not null or object
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0);
|
|
||||||
if (msg.content().empty()) {
|
|
||||||
msg_json["content"] = "";
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): empty content, set to empty string\n", i);
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): content exists: %s\n",
|
|
||||||
i, msg.content().substr(0, std::min<size_t>(200, msg.content().size())).c_str());
|
|
||||||
// Content exists, parse and ensure it's a string
|
|
||||||
json content_val;
|
|
||||||
try {
|
|
||||||
content_val = json::parse(msg.content());
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): parsed JSON, type=%s\n",
|
|
||||||
i, content_val.is_null() ? "null" :
|
|
||||||
content_val.is_object() ? "object" :
|
|
||||||
content_val.is_string() ? "string" :
|
|
||||||
content_val.is_array() ? "array" : "other");
|
|
||||||
// Handle null values - Jinja templates expect content to be a string, not null
|
|
||||||
if (content_val.is_null()) {
|
|
||||||
msg_json["content"] = "";
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): null content, converted to empty string\n", i);
|
|
||||||
} else if (content_val.is_object()) {
|
|
||||||
// If content is an object (e.g., from tool call failures/errors), convert to string
|
|
||||||
msg_json["content"] = content_val.dump();
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): object content, converted to string: %s\n",
|
|
||||||
i, content_val.dump().substr(0, std::min<size_t>(200, content_val.dump().size())).c_str());
|
|
||||||
} else if (content_val.is_string()) {
|
|
||||||
msg_json["content"] = content_val.get<std::string>();
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): string content, using as-is\n", i);
|
|
||||||
} else {
|
|
||||||
// For arrays or other types, convert to string
|
|
||||||
msg_json["content"] = content_val.dump();
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): %s content, converted to string\n",
|
|
||||||
i, content_val.is_array() ? "array" : "other type");
|
|
||||||
}
|
|
||||||
} catch (const json::parse_error&) {
|
|
||||||
// Not JSON, treat as plain string
|
|
||||||
msg_json["content"] = msg.content();
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): not JSON, using as string\n", i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Ensure all messages have content set (fallback for any unhandled cases)
|
|
||||||
// Jinja templates expect content to be present, default to empty string if not set
|
|
||||||
if (!msg_json.contains("content")) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d (role=%s): no content field, adding empty string\n",
|
|
||||||
i, msg.role().c_str());
|
|
||||||
msg_json["content"] = "";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add optional fields for OpenAI-compatible message format
|
|
||||||
if (!msg.name().empty()) {
|
|
||||||
msg_json["name"] = msg.name();
|
|
||||||
}
|
|
||||||
if (!msg.tool_call_id().empty()) {
|
|
||||||
msg_json["tool_call_id"] = msg.tool_call_id();
|
|
||||||
}
|
|
||||||
if (!msg.reasoning_content().empty()) {
|
|
||||||
msg_json["reasoning_content"] = msg.reasoning_content();
|
|
||||||
}
|
|
||||||
if (!msg.tool_calls().empty()) {
|
|
||||||
// Parse tool_calls JSON string and add to message
|
|
||||||
try {
|
|
||||||
json tool_calls = json::parse(msg.tool_calls());
|
|
||||||
msg_json["tool_calls"] = tool_calls;
|
|
||||||
SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str());
|
|
||||||
// IMPORTANT: If message has tool_calls but content is empty or not set,
|
|
||||||
// set content to space " " instead of empty string "", because llama.cpp's
|
|
||||||
// common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312),
|
|
||||||
// which causes template errors when accessing message.content[:tool_start_length]
|
|
||||||
if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get<std::string>().empty())) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls but empty content, setting to space\n", i);
|
|
||||||
msg_json["content"] = " ";
|
|
||||||
}
|
|
||||||
// Log each tool call with name and arguments
|
|
||||||
if (tool_calls.is_array()) {
|
|
||||||
for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) {
|
|
||||||
const auto& tc = tool_calls[tc_idx];
|
|
||||||
std::string tool_name = "unknown";
|
|
||||||
std::string tool_args = "{}";
|
|
||||||
if (tc.contains("function")) {
|
|
||||||
const auto& func = tc["function"];
|
|
||||||
if (func.contains("name")) {
|
|
||||||
tool_name = func["name"].get<std::string>();
|
|
||||||
}
|
|
||||||
if (func.contains("arguments")) {
|
|
||||||
tool_args = func["arguments"].is_string() ?
|
|
||||||
func["arguments"].get<std::string>() :
|
|
||||||
func["arguments"].dump();
|
|
||||||
}
|
|
||||||
} else if (tc.contains("name")) {
|
|
||||||
tool_name = tc["name"].get<std::string>();
|
|
||||||
if (tc.contains("arguments")) {
|
|
||||||
tool_args = tc["arguments"].is_string() ?
|
|
||||||
tc["arguments"].get<std::string>() :
|
|
||||||
tc["arguments"].dump();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d, tool_call %zu: name=%s, arguments=%s\n",
|
|
||||||
i, tc_idx, tool_name.c_str(), tool_args.c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (const json::parse_error& e) {
|
|
||||||
SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Log final content state before adding to array
|
|
||||||
if (msg_json.contains("content")) {
|
|
||||||
if (msg_json["content"].is_null()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i);
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content type=%s, has_value=%d\n",
|
|
||||||
i, msg_json["content"].is_string() ? "string" :
|
|
||||||
msg_json["content"].is_array() ? "array" :
|
|
||||||
msg_json["content"].is_object() ? "object" : "other",
|
|
||||||
msg_json["content"].is_null() ? 0 : 1);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i);
|
|
||||||
}
|
|
||||||
|
|
||||||
messages_json.push_back(msg_json);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final safety check: Ensure no message has null content (Jinja templates require strings)
|
// Final safety check: Ensure no message has null content (Jinja templates require strings)
|
||||||
@@ -2911,36 +2417,7 @@ public:
|
|||||||
if (body_json.contains("messages") && body_json["messages"].is_array()) {
|
if (body_json.contains("messages") && body_json["messages"].is_array()) {
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size());
|
SRV_INF("[CONTENT DEBUG] Predict: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size());
|
||||||
for (size_t idx = 0; idx < body_json["messages"].size(); idx++) {
|
for (size_t idx = 0; idx < body_json["messages"].size(); idx++) {
|
||||||
auto& msg = body_json["messages"][idx];
|
llama_grpc::normalize_template_message(body_json["messages"][idx]);
|
||||||
std::string role_str = msg.contains("role") ? msg["role"].get<std::string>() : "unknown";
|
|
||||||
if (msg.contains("content")) {
|
|
||||||
if (msg["content"].is_null()) {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str());
|
|
||||||
msg["content"] = ""; // Fix null content
|
|
||||||
} else if (role_str == "tool" && msg["content"].is_array()) {
|
|
||||||
// Tool messages must have string content, not array
|
|
||||||
// oaicompat_chat_params_parse expects tool messages to have string content
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx);
|
|
||||||
msg["content"] = msg["content"].dump();
|
|
||||||
} else if (!msg["content"].is_string() && !msg["content"].is_array()) {
|
|
||||||
// If content is object or other non-string type, convert to string for templates
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str());
|
|
||||||
if (msg["content"].is_object()) {
|
|
||||||
msg["content"] = msg["content"].dump();
|
|
||||||
} else {
|
|
||||||
msg["content"] = "";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n",
|
|
||||||
idx, role_str.c_str(),
|
|
||||||
msg["content"].is_string() ? "string" :
|
|
||||||
msg["content"].is_array() ? "array" :
|
|
||||||
msg["content"].is_object() ? "object" : "other");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str());
|
|
||||||
msg["content"] = ""; // Add missing content
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
192
backend/cpp/llama-cpp/message_content.h
Normal file
192
backend/cpp/llama-cpp/message_content.h
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
namespace llama_grpc {
|
||||||
|
|
||||||
|
// Normalizes a proto message's content string into the JSON value used when
|
||||||
|
// reconstructing OpenAI-format messages for the tokenizer (jinja) template.
|
||||||
|
//
|
||||||
|
// Shared by the streaming (PredictStream) and non-streaming (Predict) message
|
||||||
|
// reconstruction paths so the two cannot drift.
|
||||||
|
//
|
||||||
|
// LocalAI's Go layer (schema.Messages.ToProto) always sends content as a plain
|
||||||
|
// text string; multimodal media travels in separate proto fields, never inside
|
||||||
|
// content. So user/system/developer content is *only ever* opaque text and must
|
||||||
|
// NOT be JSON-sniffed: a prompt that merely looks like JSON (e.g. an ingredient
|
||||||
|
// list ["1/4 cup sugar", ...]) would otherwise be reinterpreted as structured
|
||||||
|
// content parts and rejected by oaicompat_chat_params_parse with
|
||||||
|
// "unsupported content[].type" (https://github.com/mudler/LocalAI/issues/10524).
|
||||||
|
// (developer is OpenAI's modern system alias - same "human-authored text" nature.)
|
||||||
|
//
|
||||||
|
// For assistant/tool messages we still collapse a literal JSON null/object
|
||||||
|
// (tool-call bookkeeping) to a string, but we never turn a plain string into an
|
||||||
|
// array/scalar. The array defense is therefore role-independent (arrays/scalars
|
||||||
|
// fall through for every role); the role gate only governs the null/object case.
|
||||||
|
inline nlohmann::ordered_json normalize_message_content(const std::string& role,
|
||||||
|
const std::string& content) {
|
||||||
|
nlohmann::ordered_json content_val = content;
|
||||||
|
if (role != "user" && role != "system" && role != "developer") {
|
||||||
|
try {
|
||||||
|
nlohmann::ordered_json parsed = nlohmann::ordered_json::parse(content);
|
||||||
|
if (parsed.is_null()) {
|
||||||
|
content_val = "";
|
||||||
|
} else if (parsed.is_object()) {
|
||||||
|
content_val = parsed.dump();
|
||||||
|
}
|
||||||
|
// arrays / scalars: keep the original plain-text string as-is
|
||||||
|
} catch (const nlohmann::ordered_json::parse_error&) {
|
||||||
|
// Not JSON, already the plain string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return content_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final safety pass applied to each reconstructed OpenAI message right before it
|
||||||
|
// is handed to oaicompat_chat_params_parse (jinja templating). Jinja templates
|
||||||
|
// assume content is a string: a literal null breaks slicing such as
|
||||||
|
// message.content[:N] (#7324), and a tool message with array content is rejected
|
||||||
|
// (#7528). A multimodal user message legitimately carries a typed-part array
|
||||||
|
// ({type:text}, {type:image_url}, ...), which must be left intact. Shared by the
|
||||||
|
// streaming and non-streaming paths so this invariant cannot drift between them.
|
||||||
|
inline void normalize_template_message(nlohmann::ordered_json& msg) {
|
||||||
|
if (!msg.contains("content")) {
|
||||||
|
msg["content"] = ""; // templates expect the field to exist
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nlohmann::ordered_json& content = msg["content"];
|
||||||
|
const std::string role = (msg.contains("role") && msg["role"].is_string())
|
||||||
|
? msg["role"].get<std::string>()
|
||||||
|
: std::string();
|
||||||
|
if (content.is_null()) {
|
||||||
|
content = ""; // #7324: null would crash content[:N] slicing
|
||||||
|
} else if (role == "tool" && content.is_array()) {
|
||||||
|
content = content.dump(); // #7528: tool messages must have string content
|
||||||
|
} else if (!content.is_string() && !content.is_array()) {
|
||||||
|
if (content.is_object()) {
|
||||||
|
content = content.dump(); // tool-call bookkeeping object -> string
|
||||||
|
} else {
|
||||||
|
content = ""; // other scalar (number/bool) -> empty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// string, or a non-tool (multimodal) typed-part array: leave untouched
|
||||||
|
}
|
||||||
|
|
||||||
|
// One proto message's data, flattened to plain types so the reconstruction logic
|
||||||
|
// can be shared and unit-tested without protobuf. The streaming and non-streaming
|
||||||
|
// predict paths both populate this from proto::Message + the request's media.
|
||||||
|
struct ReconstructedMessageInput {
|
||||||
|
std::string role;
|
||||||
|
std::string content; // proto.Message.content (always a plain string)
|
||||||
|
std::string name;
|
||||||
|
std::string tool_call_id;
|
||||||
|
std::string reasoning_content;
|
||||||
|
std::string tool_calls; // tool_calls as a JSON string, or empty
|
||||||
|
bool is_last_user_msg = false; // attach request media to this message
|
||||||
|
std::vector<std::string> images; // base64 (jpeg)
|
||||||
|
std::vector<std::string> audios; // base64 (wav)
|
||||||
|
std::vector<std::string> videos; // base64
|
||||||
|
};
|
||||||
|
|
||||||
|
// Appends the request's media as OpenAI typed content parts. Imperative (not
|
||||||
|
// brace-init) to avoid nlohmann's object-vs-array initializer-list ambiguity.
|
||||||
|
inline void append_media_parts(nlohmann::ordered_json& content_array,
|
||||||
|
const std::vector<std::string>& images,
|
||||||
|
const std::vector<std::string>& audios,
|
||||||
|
const std::vector<std::string>& videos) {
|
||||||
|
for (const auto& img : images) {
|
||||||
|
nlohmann::ordered_json image_chunk;
|
||||||
|
image_chunk["type"] = "image_url";
|
||||||
|
nlohmann::ordered_json image_url;
|
||||||
|
image_url["url"] = "data:image/jpeg;base64," + img;
|
||||||
|
image_chunk["image_url"] = image_url;
|
||||||
|
content_array.push_back(image_chunk);
|
||||||
|
}
|
||||||
|
for (const auto& aud : audios) {
|
||||||
|
nlohmann::ordered_json audio_chunk;
|
||||||
|
audio_chunk["type"] = "input_audio";
|
||||||
|
nlohmann::ordered_json input_audio;
|
||||||
|
input_audio["data"] = aud;
|
||||||
|
input_audio["format"] = "wav"; // default; could be made configurable
|
||||||
|
audio_chunk["input_audio"] = input_audio;
|
||||||
|
content_array.push_back(audio_chunk);
|
||||||
|
}
|
||||||
|
for (const auto& vid : videos) {
|
||||||
|
nlohmann::ordered_json video_chunk;
|
||||||
|
video_chunk["type"] = "input_video";
|
||||||
|
nlohmann::ordered_json input_video;
|
||||||
|
input_video["data"] = vid;
|
||||||
|
video_chunk["input_video"] = input_video;
|
||||||
|
content_array.push_back(video_chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstructs a single OpenAI-format message (the object fed to
|
||||||
|
// oaicompat_chat_params_parse) from a proto message. Shared by PredictStream and
|
||||||
|
// Predict so the content/multimodal/tool_calls handling cannot drift between the
|
||||||
|
// two stream modes (it previously lived as two ~150-line copies with a redundant
|
||||||
|
// Predict-only tool_calls->" " branch). Guarantees content is always a string or
|
||||||
|
// a typed-part array, never null/missing.
|
||||||
|
inline nlohmann::ordered_json build_reconstructed_message(const ReconstructedMessageInput& in) {
|
||||||
|
nlohmann::ordered_json msg_json;
|
||||||
|
msg_json["role"] = in.role;
|
||||||
|
const bool has_media = !in.images.empty() || !in.audios.empty() || !in.videos.empty();
|
||||||
|
|
||||||
|
if (!in.content.empty()) {
|
||||||
|
nlohmann::ordered_json content_val = normalize_message_content(in.role, in.content);
|
||||||
|
if (content_val.is_string() && in.is_last_user_msg && has_media) {
|
||||||
|
// Last user message + media: build a typed-part array (text first).
|
||||||
|
nlohmann::ordered_json content_array = nlohmann::ordered_json::array();
|
||||||
|
nlohmann::ordered_json text_part;
|
||||||
|
text_part["type"] = "text";
|
||||||
|
text_part["text"] = content_val.get<std::string>();
|
||||||
|
content_array.push_back(text_part);
|
||||||
|
append_media_parts(content_array, in.images, in.audios, in.videos);
|
||||||
|
msg_json["content"] = content_array;
|
||||||
|
} else if (content_val.is_null()) {
|
||||||
|
msg_json["content"] = "";
|
||||||
|
} else {
|
||||||
|
msg_json["content"] = content_val;
|
||||||
|
}
|
||||||
|
} else if (in.is_last_user_msg && has_media) {
|
||||||
|
// No text but media on the last user message: media-only typed array.
|
||||||
|
nlohmann::ordered_json content_array = nlohmann::ordered_json::array();
|
||||||
|
append_media_parts(content_array, in.images, in.audios, in.videos);
|
||||||
|
msg_json["content"] = content_array;
|
||||||
|
} else {
|
||||||
|
// Empty content (any role, incl. tool/assistant): templates need a string.
|
||||||
|
msg_json["content"] = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!in.name.empty()) {
|
||||||
|
msg_json["name"] = in.name;
|
||||||
|
}
|
||||||
|
if (!in.tool_call_id.empty()) {
|
||||||
|
msg_json["tool_call_id"] = in.tool_call_id;
|
||||||
|
}
|
||||||
|
if (!in.reasoning_content.empty()) {
|
||||||
|
msg_json["reasoning_content"] = in.reasoning_content;
|
||||||
|
}
|
||||||
|
if (!in.tool_calls.empty()) {
|
||||||
|
try {
|
||||||
|
nlohmann::ordered_json tool_calls = nlohmann::ordered_json::parse(in.tool_calls);
|
||||||
|
msg_json["tool_calls"] = tool_calls;
|
||||||
|
// tool_calls + empty/blank content: use " " not "", because llama.cpp's
|
||||||
|
// common_chat_msgs_to_json_oaicompat turns "" into null, which breaks
|
||||||
|
// templates that slice message.content[:tool_start_length] (#7324).
|
||||||
|
if (!msg_json.contains("content") ||
|
||||||
|
(msg_json["content"].is_string() && msg_json["content"].get<std::string>().empty())) {
|
||||||
|
msg_json["content"] = " ";
|
||||||
|
}
|
||||||
|
} catch (const nlohmann::ordered_json::parse_error&) {
|
||||||
|
// Malformed tool_calls JSON: leave content as-is (prior behavior).
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg_json;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace llama_grpc
|
||||||
234
backend/cpp/llama-cpp/message_content_test.cpp
Normal file
234
backend/cpp/llama-cpp/message_content_test.cpp
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
// Unit tests for the shared message-reconstruction helpers (message_content.h).
|
||||||
|
//
|
||||||
|
// Build & run standalone (nlohmann/json single header on the include path):
|
||||||
|
// g++ -std=c++17 -I<dir-with-nlohmann> message_content_test.cpp -o t && ./t
|
||||||
|
// or via CMake: -DLLAMA_GRPC_BUILD_TESTS=ON then ctest.
|
||||||
|
//
|
||||||
|
// Regression coverage for:
|
||||||
|
// #10524 - a user/system prompt that is itself a JSON-array string must stay
|
||||||
|
// plain text, never be reinterpreted as OpenAI structured parts.
|
||||||
|
// #7324 - assistant/tool null content -> "" (templates slice content[:N]);
|
||||||
|
// assistant+tool_calls+empty content -> " " (not "", which becomes null).
|
||||||
|
// #7528 - tool message array content must reach the template as a string.
|
||||||
|
// multimodal - last user message text + media -> typed-part array, media kept.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "message_content.h"
|
||||||
|
|
||||||
|
using nlohmann::ordered_json;
|
||||||
|
using llama_grpc::normalize_message_content;
|
||||||
|
using llama_grpc::normalize_template_message;
|
||||||
|
using llama_grpc::build_reconstructed_message;
|
||||||
|
using llama_grpc::ReconstructedMessageInput;
|
||||||
|
|
||||||
|
static int failures = 0;
|
||||||
|
|
||||||
|
static void check(bool ok, const std::string& name, const std::string& detail = "") {
|
||||||
|
if (!ok) {
|
||||||
|
std::cerr << "FAIL " << name << (detail.empty() ? "" : ": " + detail) << "\n";
|
||||||
|
failures++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- normalize_message_content -------------------------------------------
|
||||||
|
|
||||||
|
static void expect_norm_string(const char* name, const std::string& role,
|
||||||
|
const std::string& content, const std::string& want) {
|
||||||
|
auto got = normalize_message_content(role, content);
|
||||||
|
if (!got.is_string()) {
|
||||||
|
check(false, name, "expected a JSON string, got " +
|
||||||
|
std::string(got.is_array() ? "array" : got.is_object() ? "object" : "other") +
|
||||||
|
" (" + got.dump() + ")");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
check(got.get<std::string>() == want, name, "expected \"" + want + "\", got \"" + got.get<std::string>() + "\"");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_normalize() {
|
||||||
|
const std::string ingredients = R"(["1/4 cup brown sugar, packed","1 pound ground beef"])";
|
||||||
|
|
||||||
|
// #10524 - JSON-array text must stay a string. Role-INDEPENDENT array defense.
|
||||||
|
for (const char* role : {"user", "system", "developer", "function", "assistant", "tool"}) {
|
||||||
|
expect_norm_string((std::string("json_array_stays_text:") + role).c_str(), role, ingredients, ingredients);
|
||||||
|
}
|
||||||
|
|
||||||
|
// #10524 - user/system/developer JSON-object text stays verbatim (NOT re-dumped).
|
||||||
|
expect_norm_string("user_json_object_verbatim", "user", R"({"a":1})", R"({"a":1})");
|
||||||
|
expect_norm_string("system_json_object_verbatim", "system", R"({"a":1})", R"({"a":1})");
|
||||||
|
expect_norm_string("developer_json_object_verbatim", "developer", R"({"a":1})", R"({"a":1})");
|
||||||
|
|
||||||
|
// Plain text unchanged for all roles.
|
||||||
|
expect_norm_string("user_plain_text", "user", "hello world", "hello world");
|
||||||
|
expect_norm_string("assistant_non_json_text_kept", "assistant", "hi [unclosed", "hi [unclosed");
|
||||||
|
|
||||||
|
// #7324 boundary - user/system/developer literal "null" preserved (never parsed).
|
||||||
|
expect_norm_string("user_literal_null_stays", "user", "null", "null");
|
||||||
|
expect_norm_string("system_literal_null_stays", "system", "null", "null");
|
||||||
|
expect_norm_string("developer_literal_null_stays", "developer", "null", "null");
|
||||||
|
|
||||||
|
// #7324 - assistant/tool literal null collapses to empty string.
|
||||||
|
expect_norm_string("assistant_null_to_empty", "assistant", "null", "");
|
||||||
|
expect_norm_string("tool_null_to_empty", "tool", "null", "");
|
||||||
|
|
||||||
|
// #7324/#7528 - assistant/tool object bookkeeping stringified (stays a string).
|
||||||
|
check(normalize_message_content("assistant", R"({"tool":"x"})").is_string(), "assistant_object_stringified");
|
||||||
|
check(normalize_message_content("tool", R"({"error":"boom"})").is_string(), "tool_object_stringified");
|
||||||
|
|
||||||
|
// #10524-family - a bare scalar that parses as a JSON number stays the string.
|
||||||
|
expect_norm_string("assistant_scalar_number_stays_string", "assistant", "42", "42");
|
||||||
|
|
||||||
|
// baseline - empty content stays empty.
|
||||||
|
expect_norm_string("user_empty_stays_empty", "user", "", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- normalize_template_message (BEFORE TEMPLATE sanitizer) ---------------
|
||||||
|
|
||||||
|
static void test_template_sanitizer() {
|
||||||
|
// #7528 - a tool message with an ACTUAL array becomes a string.
|
||||||
|
{
|
||||||
|
ordered_json msg = {{"role", "tool"}, {"content", ordered_json::array({{{"type", "text"}, {"text", "r"}}})}};
|
||||||
|
normalize_template_message(msg);
|
||||||
|
check(msg["content"].is_string(), "before_template_tool_array_to_string", "got " + msg["content"].dump());
|
||||||
|
}
|
||||||
|
// #7324 - null content -> "" for any role.
|
||||||
|
{
|
||||||
|
ordered_json msg = {{"role", "assistant"}, {"content", nullptr}};
|
||||||
|
normalize_template_message(msg);
|
||||||
|
check(msg["content"].is_string() && msg["content"] == "", "before_template_null_to_empty");
|
||||||
|
}
|
||||||
|
// object content -> dumped string (would otherwise throw at the template).
|
||||||
|
{
|
||||||
|
ordered_json msg = {{"role", "assistant"}, {"content", {{"x", 1}}}};
|
||||||
|
normalize_template_message(msg);
|
||||||
|
check(msg["content"].is_string(), "before_template_object_to_string", "got " + msg["content"].dump());
|
||||||
|
}
|
||||||
|
// missing content field -> "".
|
||||||
|
{
|
||||||
|
ordered_json msg = {{"role", "user"}};
|
||||||
|
normalize_template_message(msg);
|
||||||
|
check(msg.contains("content") && msg["content"] == "", "before_template_missing_to_empty");
|
||||||
|
}
|
||||||
|
// multimodal: a well-typed user array must be left UNTOUCHED (role!=tool).
|
||||||
|
{
|
||||||
|
ordered_json parts = ordered_json::array();
|
||||||
|
parts.push_back({{"type", "text"}, {"text", "x"}});
|
||||||
|
ordered_json img; img["type"] = "image_url"; img["image_url"] = {{"url", "data:..."}};
|
||||||
|
parts.push_back(img);
|
||||||
|
ordered_json msg = {{"role", "user"}, {"content", parts}};
|
||||||
|
normalize_template_message(msg);
|
||||||
|
check(msg["content"].is_array() && msg["content"].size() == 2, "before_template_user_typed_array_preserved",
|
||||||
|
"got " + msg["content"].dump());
|
||||||
|
}
|
||||||
|
// a plain string is left untouched.
|
||||||
|
{
|
||||||
|
ordered_json msg = {{"role", "user"}, {"content", "hello"}};
|
||||||
|
normalize_template_message(msg);
|
||||||
|
check(msg["content"] == "hello", "before_template_string_untouched");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- build_reconstructed_message ----------------------------------------
|
||||||
|
|
||||||
|
static void test_reconstruction() {
|
||||||
|
const std::string ingredients = R"(["1/4 cup brown sugar","1 pound ground beef"])";
|
||||||
|
|
||||||
|
// #10524 end-state - user JSON-array text, no media -> string content.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "user"; in.content = ingredients;
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"].is_string() && m["content"] == ingredients, "recon_user_json_array_string",
|
||||||
|
"got " + m["content"].dump());
|
||||||
|
}
|
||||||
|
// multimodal - user text + one image on last user msg -> typed array, image kept.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "user"; in.content = ingredients; in.is_last_user_msg = true;
|
||||||
|
in.images.push_back("BASE64IMG");
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"].is_array() && m["content"].size() == 2, "recon_multimodal_text_plus_image",
|
||||||
|
"got " + m["content"].dump());
|
||||||
|
check(m["content"][0]["type"] == "text" && m["content"][0]["text"] == ingredients, "recon_multimodal_text_first");
|
||||||
|
check(m["content"][1]["type"] == "image_url", "recon_multimodal_image_kept");
|
||||||
|
}
|
||||||
|
// multimodal media-only - empty text + image on last user msg.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "user"; in.content = ""; in.is_last_user_msg = true;
|
||||||
|
in.images.push_back("BASE64IMG");
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"].is_array() && m["content"].size() == 1 && m["content"][0]["type"] == "image_url",
|
||||||
|
"recon_media_only", "got " + m["content"].dump());
|
||||||
|
}
|
||||||
|
// #7528 - tool array-string content stays a string.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "tool"; in.content = R"(["a","b"])"; in.tool_call_id = "call_1";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"].is_string() && m["content"] == R"(["a","b"])", "recon_tool_array_string",
|
||||||
|
"got " + m["content"].dump());
|
||||||
|
check(m["tool_call_id"] == "call_1", "recon_tool_call_id_set");
|
||||||
|
}
|
||||||
|
// tool empty content -> "".
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "tool"; in.content = "";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"].is_string() && m["content"] == "", "recon_tool_empty_to_string");
|
||||||
|
}
|
||||||
|
// #7324 - assistant + tool_calls + empty content -> " " (single space, not "").
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "assistant"; in.content = "";
|
||||||
|
in.tool_calls = R"([{"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}])";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"].is_string() && m["content"] == " ", "recon_toolcalls_empty_content_space",
|
||||||
|
"got " + m["content"].dump());
|
||||||
|
check(m["tool_calls"].is_array() && m["tool_calls"].size() == 1, "recon_toolcalls_parsed");
|
||||||
|
}
|
||||||
|
// assistant + tool_calls + real content keeps the content.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "assistant"; in.content = "I'll call f";
|
||||||
|
in.tool_calls = R"([{"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}])";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"] == "I'll call f", "recon_toolcalls_with_content_kept");
|
||||||
|
}
|
||||||
|
// assistant null content -> "".
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "assistant"; in.content = "null";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"] == "", "recon_assistant_null_to_empty");
|
||||||
|
}
|
||||||
|
// malformed tool_calls JSON must not throw; content preserved.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "assistant"; in.content = "hi"; in.tool_calls = "{not json";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["content"] == "hi" && !m.contains("tool_calls"), "recon_malformed_toolcalls_safe");
|
||||||
|
}
|
||||||
|
// optional fields: name + reasoning carried through.
|
||||||
|
{
|
||||||
|
ReconstructedMessageInput in;
|
||||||
|
in.role = "tool"; in.content = "result"; in.name = "get_weather"; in.reasoning_content = "thinking";
|
||||||
|
auto m = build_reconstructed_message(in);
|
||||||
|
check(m["name"] == "get_weather" && m["reasoning_content"] == "thinking", "recon_optional_fields");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
test_normalize();
|
||||||
|
test_template_sanitizer();
|
||||||
|
test_reconstruction();
|
||||||
|
|
||||||
|
if (failures == 0) {
|
||||||
|
std::cout << "OK: all message_content tests passed\n";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
std::cerr << failures << " test(s) failed\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
@@ -18,6 +18,10 @@ done
|
|||||||
|
|
||||||
cp -r CMakeLists.txt llama.cpp/tools/grpc-server/
|
cp -r CMakeLists.txt llama.cpp/tools/grpc-server/
|
||||||
cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
|
cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
|
||||||
|
# Shared message-reconstruction helpers (included by grpc-server.cpp) and their
|
||||||
|
# unit test (compiled only when -DLLAMA_GRPC_BUILD_TESTS=ON).
|
||||||
|
cp -r message_content.h llama.cpp/tools/grpc-server/
|
||||||
|
cp -r message_content_test.cpp llama.cpp/tools/grpc-server/
|
||||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
|
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
|
||||||
cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/
|
cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/
|
||||||
|
|
||||||
|
|||||||
71
backend/cpp/run-unit-tests.sh
Executable file
71
backend/cpp/run-unit-tests.sh
Executable file
@@ -0,0 +1,71 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Discovers and runs every standalone C++ unit test under backend/cpp/.
|
||||||
|
#
|
||||||
|
# A "standalone" unit test is a *_test.cpp that depends only on the C++ standard
|
||||||
|
# library and nlohmann/json (single header) - i.e. it exercises pure helpers and
|
||||||
|
# does not need the full llama.cpp + gRPC backend build. Tests that DO need the
|
||||||
|
# backend build use the CMake/ctest path (e.g. -DLLAMA_GRPC_BUILD_TESTS=ON)
|
||||||
|
# instead and are skipped here.
|
||||||
|
#
|
||||||
|
# This keeps CI generic: adding a new pure-C++ unit test file named *_test.cpp in
|
||||||
|
# an active backend source dir is picked up automatically, with no CI edits.
|
||||||
|
#
|
||||||
|
# Env:
|
||||||
|
# NLOHMANN_INCLUDE include dir that contains nlohmann/json.hpp. If unset, the
|
||||||
|
# nlohmann/json single header is fetched to a temp dir.
|
||||||
|
# CXX compiler (default: g++).
|
||||||
|
# JSON_VERSION nlohmann/json tag to fetch when NLOHMANN_INCLUDE is unset
|
||||||
|
# (default: v3.11.3).
|
||||||
|
set -uo pipefail
|
||||||
|
|
||||||
|
ROOT="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
CXX="${CXX:-g++}"
|
||||||
|
JSON_VERSION="${JSON_VERSION:-v3.11.3}"
|
||||||
|
|
||||||
|
JSON_INC="${NLOHMANN_INCLUDE:-}"
|
||||||
|
if [ -z "$JSON_INC" ]; then
|
||||||
|
JSON_INC="$(mktemp -d)"
|
||||||
|
mkdir -p "$JSON_INC/nlohmann"
|
||||||
|
echo "Fetching nlohmann/json ${JSON_VERSION} single header..."
|
||||||
|
if ! curl -L -sf \
|
||||||
|
"https://raw.githubusercontent.com/nlohmann/json/${JSON_VERSION}/single_include/nlohmann/json.hpp" \
|
||||||
|
-o "$JSON_INC/nlohmann/json.hpp"; then
|
||||||
|
echo "ERROR: failed to fetch nlohmann/json header" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Active source dirs only - exclude per-variant build copies, dev snapshots and
|
||||||
|
# the vendored upstream llama.cpp tree.
|
||||||
|
mapfile -t tests < <(find "$ROOT" -name '*_test.cpp' \
|
||||||
|
-not -path '*/llama.cpp/*' \
|
||||||
|
-not -path '*-build/*' \
|
||||||
|
-not -path '*-dev/*' \
|
||||||
|
-not -path '*fallback*' | sort)
|
||||||
|
|
||||||
|
if [ "${#tests[@]}" -eq 0 ]; then
|
||||||
|
echo "No standalone C++ unit tests found under $ROOT"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
fail=0
|
||||||
|
for test_src in "${tests[@]}"; do
|
||||||
|
name="$(basename "$test_src" .cpp)"
|
||||||
|
bin="$(mktemp -d)/$name"
|
||||||
|
echo "==> $test_src"
|
||||||
|
if ! "$CXX" -std=c++17 -Wall -Wextra \
|
||||||
|
-I"$JSON_INC" -I"$(dirname "$test_src")" \
|
||||||
|
"$test_src" -o "$bin"; then
|
||||||
|
echo "COMPILE FAILED: $test_src" >&2
|
||||||
|
fail=1
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
if ! "$bin"; then
|
||||||
|
echo "TEST FAILED: $test_src" >&2
|
||||||
|
fail=1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Ran ${#tests[@]} standalone C++ unit test file(s)"
|
||||||
|
exit "$fail"
|
||||||
@@ -13,8 +13,14 @@ if [ "$(uname)" != "Darwin" ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$(uname)" = "Darwin" ]; then
|
if [ "$(uname)" = "Darwin" ]; then
|
||||||
# macOS: single dylib variant (Metal or Accelerate)
|
# macOS: single fallback variant (Metal/Accelerate). The cmake build emits a
|
||||||
LIBRARY="$CURDIR/libgowhisper-fallback.dylib"
|
# Mach-O named .so, but tolerate .dylib too — pick whichever exists so the Go
|
||||||
|
# loader doesn't panic on a hardcoded name that isn't on disk.
|
||||||
|
if [ -e "$CURDIR/libgowhisper-fallback.dylib" ]; then
|
||||||
|
LIBRARY="$CURDIR/libgowhisper-fallback.dylib"
|
||||||
|
else
|
||||||
|
LIBRARY="$CURDIR/libgowhisper-fallback.so"
|
||||||
|
fi
|
||||||
export DYLD_LIBRARY_PATH="$CURDIR"/lib:$DYLD_LIBRARY_PATH
|
export DYLD_LIBRARY_PATH="$CURDIR"/lib:$DYLD_LIBRARY_PATH
|
||||||
else
|
else
|
||||||
LIBRARY="$CURDIR/libgowhisper-fallback.so"
|
LIBRARY="$CURDIR/libgowhisper-fallback.so"
|
||||||
|
|||||||
@@ -7,3 +7,7 @@ setuptools
|
|||||||
six
|
six
|
||||||
scipy
|
scipy
|
||||||
numpy
|
numpy
|
||||||
|
# fish-speech is installed editable with --no-build-isolation, so the build
|
||||||
|
# backends of its transitive deps must already be in the venv. One of them
|
||||||
|
# builds a Rust extension and needs setuptools-rust present at metadata time.
|
||||||
|
setuptools-rust
|
||||||
|
|||||||
@@ -11,14 +11,31 @@ fi
|
|||||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
|
||||||
installRequirements
|
installRequirements
|
||||||
|
|
||||||
# Fetch convert_hf_to_gguf.py from llama.cpp
|
# Fetch convert_hf_to_gguf.py from llama.cpp.
|
||||||
|
# Upstream split the model-specific logic out of the single file into a
|
||||||
|
# sibling `conversion/` package (convert_hf_to_gguf.py now does
|
||||||
|
# `from conversion import ...`), so a single-file download no longer runs —
|
||||||
|
# it fails with `ModuleNotFoundError: No module named 'conversion'`. We clone
|
||||||
|
# the repo and copy both the script and the package; Python puts the script's
|
||||||
|
# own directory on sys.path[0], so the package resolves when placed beside it.
|
||||||
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
||||||
|
LLAMA_CPP_SRC="${EDIR}/llama.cpp"
|
||||||
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
||||||
if [ ! -f "${CONVERT_SCRIPT}" ]; then
|
|
||||||
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
cloneLlamaCpp() {
|
||||||
curl -L --fail --retry 3 \
|
if [ ! -d "${LLAMA_CPP_SRC}/.git" ]; then
|
||||||
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
|
git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
|
||||||
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py."
|
https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
|
||||||
|
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ ! -f "${CONVERT_SCRIPT}" ] || [ ! -d "${EDIR}/conversion" ]; then
|
||||||
|
echo "Fetching convert_hf_to_gguf.py + conversion/ from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
cloneLlamaCpp
|
||||||
|
cp "${LLAMA_CPP_SRC}/convert_hf_to_gguf.py" "${CONVERT_SCRIPT}"
|
||||||
|
rm -rf "${EDIR}/conversion"
|
||||||
|
cp -r "${LLAMA_CPP_SRC}/conversion" "${EDIR}/conversion"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Install gguf package from the same llama.cpp commit to keep them in sync
|
# Install gguf package from the same llama.cpp commit to keep them in sync
|
||||||
@@ -41,12 +58,7 @@ QUANTIZE_BIN="${EDIR}/llama-quantize"
|
|||||||
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
|
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
|
||||||
if command -v cmake &>/dev/null; then
|
if command -v cmake &>/dev/null; then
|
||||||
echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
LLAMA_CPP_SRC="${EDIR}/llama.cpp"
|
cloneLlamaCpp # reuses the clone fetched for convert_hf_to_gguf.py
|
||||||
if [ ! -d "${LLAMA_CPP_SRC}" ]; then
|
|
||||||
git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
|
|
||||||
https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
|
|
||||||
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
|
|
||||||
fi
|
|
||||||
cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
|
cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
|
||||||
cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
|
cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
|
||||||
cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
|
cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
|
||||||
|
|||||||
@@ -85,9 +85,15 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
|||||||
# The resulting binary still requires an AVX-512 capable CPU at runtime,
|
# The resulting binary still requires an AVX-512 capable CPU at runtime,
|
||||||
# same constraint sglang upstream documents in docker/xeon.Dockerfile.
|
# same constraint sglang upstream documents in docker/xeon.Dockerfile.
|
||||||
|
|
||||||
|
# Pin the source build to the same release the GPU path floors on
|
||||||
|
# (0.5.11, see requirements-cublas12-after.txt). An unpinned master clone
|
||||||
|
# pulls in newer CPU kernels (e.g. mamba/fla.cpp) that fail to compile
|
||||||
|
# (constexpr non-constant + kineto_LIBRARY-NOTFOUND). Bump deliberately.
|
||||||
|
SGLANG_VERSION="${SGLANG_VERSION:-v0.5.11}"
|
||||||
_sgl_src=$(mktemp -d)
|
_sgl_src=$(mktemp -d)
|
||||||
trap 'rm -rf "${_sgl_src}"' EXIT
|
trap 'rm -rf "${_sgl_src}"' EXIT
|
||||||
git clone --depth 1 https://github.com/sgl-project/sglang "${_sgl_src}/sglang"
|
git clone --depth 1 --branch "${SGLANG_VERSION}" \
|
||||||
|
https://github.com/sgl-project/sglang "${_sgl_src}/sglang"
|
||||||
|
|
||||||
# Patch -march=native → -march=sapphirerapids in the CPU kernel CMakeLists
|
# Patch -march=native → -march=sapphirerapids in the CPU kernel CMakeLists
|
||||||
sed -i 's/-march=native/-march=sapphirerapids/g' \
|
sed -i 's/-march=native/-march=sapphirerapids/g' \
|
||||||
|
|||||||
@@ -570,6 +570,43 @@ impl Backend for KokorosService {
|
|||||||
) -> Result<Response<backend::Result>, Status> {
|
) -> Result<Response<backend::Result>, Status> {
|
||||||
Err(Status::unimplemented("Not supported"))
|
Err(Status::unimplemented("Not supported"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn sound_detection(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::SoundDetectionRequest>,
|
||||||
|
) -> Result<Response<backend::SoundDetectionResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn depth(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::DepthRequest>,
|
||||||
|
) -> Result<Response<backend::DepthResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn token_classify(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::TokenClassifyRequest>,
|
||||||
|
) -> Result<Response<backend::TokenClassifyResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_: Request<backend::ScoreRequest>,
|
||||||
|
) -> Result<Response<backend::ScoreResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwardStream = ReceiverStream<Result<backend::ForwardReply, Status>>;
|
||||||
|
|
||||||
|
async fn forward(
|
||||||
|
&self,
|
||||||
|
_: Request<tonic::Streaming<backend::ForwardRequest>>,
|
||||||
|
) -> Result<Response<Self::ForwardStream>, Status> {
|
||||||
|
Err(Status::unimplemented("Not supported"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -37,8 +37,6 @@ func (a *Application) RestartAgentJobService() error {
|
|||||||
if d.JobStore != nil {
|
if d.JobStore != nil {
|
||||||
agentJobService.SetDistributedJobStore(d.JobStore)
|
agentJobService.SetDistributedJobStore(d.JobStore)
|
||||||
}
|
}
|
||||||
// Keep agent tasks consistent across replicas (same client the dispatcher uses).
|
|
||||||
agentJobService.SetTaskSyncNATS(d.Nats)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the service
|
// Start the service
|
||||||
|
|||||||
@@ -604,10 +604,6 @@ func (a *Application) StartAgentPool() {
|
|||||||
usm.SetJobDBStore(s)
|
usm.SetJobDBStore(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Keep per-user agent tasks consistent across replicas (nil in standalone).
|
|
||||||
if d := a.Distributed(); d != nil {
|
|
||||||
usm.SetJobSyncNATS(d.Nats)
|
|
||||||
}
|
|
||||||
aps.SetUserServicesManager(usm)
|
aps.SetUserServicesManager(usm)
|
||||||
|
|
||||||
a.agentPoolService.Store(aps)
|
a.agentPoolService.Store(aps)
|
||||||
|
|||||||
@@ -280,9 +280,6 @@ func New(opts ...config.AppOption) (*Application, error) {
|
|||||||
if application.agentJobService != nil {
|
if application.agentJobService != nil {
|
||||||
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
||||||
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
||||||
// Keep agent tasks consistent across replicas (jobs already sync via the
|
|
||||||
// dispatcher + DB read-through). Same NATS client the dispatcher uses.
|
|
||||||
application.agentJobService.SetTaskSyncNATS(distSvc.Nats)
|
|
||||||
}
|
}
|
||||||
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
||||||
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
||||||
|
|||||||
@@ -23,10 +23,8 @@ import (
|
|||||||
|
|
||||||
"github.com/mudler/LocalAI/core/application"
|
"github.com/mudler/LocalAI/core/application"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/finetune"
|
"github.com/mudler/LocalAI/core/services/finetune"
|
||||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/LocalAI/core/services/nodes"
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
"github.com/mudler/LocalAI/core/services/quantization"
|
"github.com/mudler/LocalAI/core/services/quantization"
|
||||||
|
|
||||||
@@ -402,45 +400,25 @@ func API(application *application.Application) (*echo.Echo, error) {
|
|||||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||||
// Fine-tuning routes
|
// Fine-tuning routes
|
||||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||||
// In distributed mode pass the shared NATS client + PostgreSQL store so
|
|
||||||
// fine-tune jobs stay consistent across replicas (the SyncedMap broadcasts
|
|
||||||
// mutations and hydrates from the DB); standalone passes nil for both.
|
|
||||||
var ftNats messaging.MessagingClient
|
|
||||||
var ftStore *distributed.FineTuneStore
|
|
||||||
if d := application.Distributed(); d != nil {
|
|
||||||
ftNats = d.Nats
|
|
||||||
if d.DistStores != nil && d.DistStores.FineTune != nil {
|
|
||||||
ftStore = d.DistStores.FineTune
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ftService := finetune.NewFineTuneService(
|
ftService := finetune.NewFineTuneService(
|
||||||
application.ApplicationConfig(),
|
application.ApplicationConfig(),
|
||||||
application.ModelLoader(),
|
application.ModelLoader(),
|
||||||
application.ModelConfigLoader(),
|
application.ModelConfigLoader(),
|
||||||
ftNats,
|
|
||||||
ftStore,
|
|
||||||
)
|
)
|
||||||
|
if d := application.Distributed(); d != nil {
|
||||||
|
ftService.SetNATSClient(d.Nats)
|
||||||
|
if d.DistStores != nil && d.DistStores.FineTune != nil {
|
||||||
|
ftService.SetFineTuneStore(d.DistStores.FineTune)
|
||||||
|
}
|
||||||
|
}
|
||||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||||
|
|
||||||
// Quantization routes
|
// Quantization routes
|
||||||
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
||||||
// In distributed mode pass the shared NATS client + PostgreSQL store so
|
|
||||||
// quantization jobs stay consistent across replicas (the SyncedMap broadcasts
|
|
||||||
// mutations and hydrates from the DB); standalone passes nil for both.
|
|
||||||
var quantNats messaging.MessagingClient
|
|
||||||
var quantStore *distributed.QuantStore
|
|
||||||
if d := application.Distributed(); d != nil {
|
|
||||||
quantNats = d.Nats
|
|
||||||
if d.DistStores != nil && d.DistStores.Quant != nil {
|
|
||||||
quantStore = d.DistStores.Quant
|
|
||||||
}
|
|
||||||
}
|
|
||||||
qService := quantization.NewQuantizationService(
|
qService := quantization.NewQuantizationService(
|
||||||
application.ApplicationConfig(),
|
application.ApplicationConfig(),
|
||||||
application.ModelLoader(),
|
application.ModelLoader(),
|
||||||
application.ModelConfigLoader(),
|
application.ModelConfigLoader(),
|
||||||
quantNats,
|
|
||||||
quantStore,
|
|
||||||
)
|
)
|
||||||
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
||||||
|
|
||||||
|
|||||||
@@ -68,6 +68,32 @@ var _ = Describe("LLM tests", func() {
|
|||||||
Expect(protoMessages[0].Content).To(Equal("Hello World"))
|
Expect(protoMessages[0].Content).To(Equal("Hello World"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Regression for mudler/LocalAI#10524: a text part whose inner text is
|
||||||
|
// itself a JSON-array string (mealie sends an ingredient list) must
|
||||||
|
// flatten to that exact string verbatim. ToProto must NOT escape or
|
||||||
|
// restructure it - the C++ backend then treats it as opaque text. This
|
||||||
|
// pins the precise Go-side input that produced the "unsupported
|
||||||
|
// content[].type" gRPC error before the backend stopped re-parsing it.
|
||||||
|
It("flattens a JSON-array-looking text part to the verbatim string (#10524)", func() {
|
||||||
|
ingredients := `["1/4 cup brown sugar, packed","1 pound ground beef"]`
|
||||||
|
messages := Messages{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": ingredients,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
protoMessages := messages.ToProto()
|
||||||
|
|
||||||
|
Expect(protoMessages).To(HaveLen(1))
|
||||||
|
Expect(protoMessages[0].Content).To(Equal(ingredients))
|
||||||
|
})
|
||||||
|
|
||||||
It("should convert message with tool_calls", func() {
|
It("should convert message with tool_calls", func() {
|
||||||
messages := Messages{
|
messages := Messages{
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -30,8 +30,6 @@ import (
|
|||||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services/jobs"
|
"github.com/mudler/LocalAI/core/services/jobs"
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
"github.com/mudler/LocalAI/core/templates"
|
"github.com/mudler/LocalAI/core/templates"
|
||||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
@@ -45,18 +43,8 @@ type AgentJobService struct {
|
|||||||
configLoader *config.ModelConfigLoader
|
configLoader *config.ModelConfigLoader
|
||||||
evaluator *templates.Evaluator
|
evaluator *templates.Evaluator
|
||||||
|
|
||||||
// tasks is the cross-replica task store: an in-memory map kept consistent
|
|
||||||
// across replicas via NATS, with read-through to the configured persister
|
|
||||||
// (file in standalone, PostgreSQL in distributed). Unlike jobs - which already
|
|
||||||
// converge via the dispatcher + DB read-through - tasks previously read
|
|
||||||
// in-memory only, so ListTasks went stale on non-originating replicas.
|
|
||||||
tasks *syncstate.SyncedMap[string, schema.Task]
|
|
||||||
// taskNats is the distributed NATS client backing the tasks SyncedMap. It is
|
|
||||||
// not available at construction time, so it is injected via SetTaskSyncNATS
|
|
||||||
// during distributed wiring; nil keeps tasks in-memory-only (standalone).
|
|
||||||
taskNats messaging.MessagingClient
|
|
||||||
|
|
||||||
// Storage (in-memory primary, persister for secondary persistence)
|
// Storage (in-memory primary, persister for secondary persistence)
|
||||||
|
tasks *xsync.SyncedMap[string, schema.Task]
|
||||||
jobs *xsync.SyncedMap[string, schema.Job]
|
jobs *xsync.SyncedMap[string, schema.Job]
|
||||||
persister JobPersister
|
persister JobPersister
|
||||||
userID string // Scoping: empty for global (main service), set for per-user instances
|
userID string // Scoping: empty for global (main service), set for per-user instances
|
||||||
@@ -108,31 +96,6 @@ func (s *AgentJobService) SetDistributedJobStore(store *jobs.JobStore) {
|
|||||||
s.persister = &dbJobPersister{store: store}
|
s.persister = &dbJobPersister{store: store}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTaskSyncNATS wires the distributed NATS client used to keep agent *tasks*
|
|
||||||
// consistent across replicas (jobs already converge via the dispatcher + DB
|
|
||||||
// read-through, so they are left untouched). The client is not available when the
|
|
||||||
// service is constructed, so it is injected here during distributed wiring and the
|
|
||||||
// tasks SyncedMap is rebuilt to pick it up. It is always called before Start /
|
|
||||||
// hydrate, while the map is still empty, so rebuilding loses no state. Passing nil
|
|
||||||
// (standalone) keeps the map in-memory-only with no broadcast.
|
|
||||||
func (s *AgentJobService) SetTaskSyncNATS(nats messaging.MessagingClient) {
|
|
||||||
s.taskNats = nats
|
|
||||||
s.buildTasksMap()
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildTasksMap (re)constructs the cross-replica tasks SyncedMap from the current
|
|
||||||
// taskNats. The Store adapter reads s.persister/s.userID live, so a persister swap
|
|
||||||
// (SetDistributedJobStore) needs no rebuild; only the NATS client, fixed at
|
|
||||||
// New-time, forces one - hence SetTaskSyncNATS calls this.
|
|
||||||
func (s *AgentJobService) buildTasksMap() {
|
|
||||||
s.tasks = syncstate.New(syncstate.Config[string, schema.Task]{
|
|
||||||
Name: "agent.tasks",
|
|
||||||
Key: func(t schema.Task) string { return t.ID },
|
|
||||||
Nats: s.taskNats,
|
|
||||||
Store: &taskStoreAdapter{svc: s},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dispatcher returns the distributed dispatcher (nil if not in distributed mode).
|
// Dispatcher returns the distributed dispatcher (nil if not in distributed mode).
|
||||||
func (s *AgentJobService) Dispatcher() DistributedDispatcher {
|
func (s *AgentJobService) Dispatcher() DistributedDispatcher {
|
||||||
return s.dispatcher
|
return s.dispatcher
|
||||||
@@ -143,6 +106,13 @@ func (s *AgentJobService) DBStore() *jobs.JobStore {
|
|||||||
return s.rawDBStore
|
return s.rawDBStore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// saveTasks persists tasks via the configured persister (file or DB).
|
||||||
|
func (s *AgentJobService) saveTasks(task schema.Task) {
|
||||||
|
if err := s.persister.SaveTask(s.userID, task); err != nil {
|
||||||
|
xlog.Warn("Failed to persist task", "error", err, "task_id", task.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// saveJobs persists jobs via the configured persister (file or DB).
|
// saveJobs persists jobs via the configured persister (file or DB).
|
||||||
func (s *AgentJobService) saveJobs(job schema.Job) {
|
func (s *AgentJobService) saveJobs(job schema.Job) {
|
||||||
if err := s.persister.SaveJob(s.userID, job); err != nil {
|
if err := s.persister.SaveJob(s.userID, job); err != nil {
|
||||||
@@ -159,8 +129,18 @@ func (s *AgentJobService) LoadFromDB() {
|
|||||||
|
|
||||||
// loadFromPersister loads tasks and jobs from the configured persister into memory.
|
// loadFromPersister loads tasks and jobs from the configured persister into memory.
|
||||||
func (s *AgentJobService) loadFromPersister() {
|
func (s *AgentJobService) loadFromPersister() {
|
||||||
if err := s.hydrateTasks(s.appConfig.Context); err != nil {
|
if tasks, err := s.persister.LoadTasks(s.userID); err != nil {
|
||||||
xlog.Warn("Failed to load tasks from persister", "error", err)
|
xlog.Warn("Failed to load tasks from persister", "error", err)
|
||||||
|
} else {
|
||||||
|
for _, task := range tasks {
|
||||||
|
s.tasks.Set(task.ID, task)
|
||||||
|
if task.Enabled && task.Cron != "" {
|
||||||
|
if err := s.ScheduleCronTask(task); err != nil {
|
||||||
|
xlog.Warn("Failed to schedule cron task on load", "error", err, "task_id", task.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlog.Info("Loaded tasks from persister", "count", len(tasks))
|
||||||
}
|
}
|
||||||
|
|
||||||
if loadedJobs, err := s.persister.LoadJobs(s.userID); err != nil {
|
if loadedJobs, err := s.persister.LoadJobs(s.userID); err != nil {
|
||||||
@@ -173,27 +153,6 @@ func (s *AgentJobService) loadFromPersister() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// hydrateTasks loads tasks into the cross-replica SyncedMap and (re)schedules
|
|
||||||
// cron entries for enabled tasks. Hydration goes through the SyncedMap's Store
|
|
||||||
// read-through (Start), not Set, so it neither re-persists nor re-broadcasts the
|
|
||||||
// loaded tasks. Each service instance hydrates exactly once: the main service via
|
|
||||||
// Start -> loadFromPersister, per-user services via LoadFromDB or LoadTasksFromFile.
|
|
||||||
func (s *AgentJobService) hydrateTasks(ctx context.Context) error {
|
|
||||||
if err := s.tasks.Start(ctx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
tasks := s.tasks.List()
|
|
||||||
for _, task := range tasks {
|
|
||||||
if task.Enabled && task.Cron != "" {
|
|
||||||
if err := s.ScheduleCronTask(task); err != nil {
|
|
||||||
xlog.Warn("Failed to schedule cron task on load", "error", err, "task_id", task.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
xlog.Info("Loaded tasks from persister", "count", len(tasks))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// JobExecution represents a job to be executed
|
// JobExecution represents a job to be executed
|
||||||
type JobExecution struct {
|
type JobExecution struct {
|
||||||
Job schema.Job
|
Job schema.Job
|
||||||
@@ -241,19 +200,21 @@ func NewAgentJobServiceWithPaths(
|
|||||||
) *AgentJobService {
|
) *AgentJobService {
|
||||||
retentionDays := cmp.Or(appConfig.AgentJobRetentionDays, 30)
|
retentionDays := cmp.Or(appConfig.AgentJobRetentionDays, 30)
|
||||||
|
|
||||||
|
tasks := xsync.NewSyncedMap[string, schema.Task]()
|
||||||
jobsMap := xsync.NewSyncedMap[string, schema.Job]()
|
jobsMap := xsync.NewSyncedMap[string, schema.Job]()
|
||||||
|
|
||||||
s := &AgentJobService{
|
return &AgentJobService{
|
||||||
appConfig: appConfig,
|
appConfig: appConfig,
|
||||||
modelLoader: modelLoader,
|
modelLoader: modelLoader,
|
||||||
configLoader: configLoader,
|
configLoader: configLoader,
|
||||||
evaluator: evaluator,
|
evaluator: evaluator,
|
||||||
|
tasks: tasks,
|
||||||
jobs: jobsMap,
|
jobs: jobsMap,
|
||||||
persister: &fileJobPersister{
|
persister: &fileJobPersister{
|
||||||
|
tasks: tasks,
|
||||||
jobs: jobsMap,
|
jobs: jobsMap,
|
||||||
tasksFile: tasksFile,
|
tasksFile: tasksFile,
|
||||||
jobsFile: jobsFile,
|
jobsFile: jobsFile,
|
||||||
taskSet: make(map[string]schema.Task),
|
|
||||||
},
|
},
|
||||||
jobQueue: make(chan JobExecution, 100), // Buffer for 100 jobs
|
jobQueue: make(chan JobExecution, 100), // Buffer for 100 jobs
|
||||||
cancellations: xsync.NewSyncedMap[string, context.CancelFunc](),
|
cancellations: xsync.NewSyncedMap[string, context.CancelFunc](),
|
||||||
@@ -261,17 +222,25 @@ func NewAgentJobServiceWithPaths(
|
|||||||
cronEntries: xsync.NewSyncedMap[string, cron.EntryID](),
|
cronEntries: xsync.NewSyncedMap[string, cron.EntryID](),
|
||||||
retentionDays: retentionDays,
|
retentionDays: retentionDays,
|
||||||
}
|
}
|
||||||
// Build the cross-replica tasks map standalone (nil NATS); SetTaskSyncNATS
|
|
||||||
// rebuilds it with the distributed client once that is available, before Start.
|
|
||||||
s.buildTasksMap()
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadTasksFromFile loads tasks from the persister into the in-memory map
|
// LoadTasksFromFile loads tasks from the persister into the in-memory map
|
||||||
// and schedules cron entries. Named "FromFile" for backward compat; in DB
|
// and schedules cron entries. Named "FromFile" for backward compat; in DB
|
||||||
// mode it loads from the database.
|
// mode it loads from the database.
|
||||||
func (s *AgentJobService) LoadTasksFromFile() error {
|
func (s *AgentJobService) LoadTasksFromFile() error {
|
||||||
return s.hydrateTasks(s.appConfig.Context)
|
tasks, err := s.persister.LoadTasks(s.userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, task := range tasks {
|
||||||
|
s.tasks.Set(task.ID, task)
|
||||||
|
if task.Enabled && task.Cron != "" {
|
||||||
|
if err := s.ScheduleCronTask(task); err != nil {
|
||||||
|
xlog.Warn("Failed to schedule cron task on load", "error", err, "task_id", task.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTasksToFile flushes the current tasks map via the persister. File
|
// SaveTasksToFile flushes the current tasks map via the persister. File
|
||||||
@@ -324,12 +293,8 @@ func (s *AgentJobService) CreateTask(task schema.Task) (string, error) {
|
|||||||
task.Enabled = true // Default to enabled
|
task.Enabled = true // Default to enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store task: Set updates the in-memory map, write-throughs to the persister
|
// Store task
|
||||||
// (file or DB), and broadcasts the create to peer replicas. Background ctx
|
s.tasks.Set(id, task)
|
||||||
// because CreateTask carries no request ctx (mirrors the finetune service).
|
|
||||||
if err := s.tasks.Set(context.Background(), task); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to persist task: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Schedule cron if enabled and has cron expression
|
// Schedule cron if enabled and has cron expression
|
||||||
if task.Enabled && task.Cron != "" {
|
if task.Enabled && task.Cron != "" {
|
||||||
@@ -338,15 +303,16 @@ func (s *AgentJobService) CreateTask(task schema.Task) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.saveTasks(task)
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTask updates an existing task
|
// UpdateTask updates an existing task
|
||||||
func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
||||||
existing, ok := s.tasks.Get(id)
|
if !s.tasks.Exists(id) {
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
return fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
||||||
}
|
}
|
||||||
|
existing := s.tasks.Get(id)
|
||||||
|
|
||||||
// Preserve ID and CreatedAt
|
// Preserve ID and CreatedAt
|
||||||
task.ID = id
|
task.ID = id
|
||||||
@@ -358,10 +324,8 @@ func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
|||||||
s.UnscheduleCronTask(id)
|
s.UnscheduleCronTask(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store updated task: write-through + broadcast (see CreateTask).
|
// Store updated task
|
||||||
if err := s.tasks.Set(context.Background(), task); err != nil {
|
s.tasks.Set(id, task)
|
||||||
return fmt.Errorf("failed to persist task: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Schedule new cron if enabled and has cron expression
|
// Schedule new cron if enabled and has cron expression
|
||||||
if task.Enabled && task.Cron != "" {
|
if task.Enabled && task.Cron != "" {
|
||||||
@@ -370,22 +334,24 @@ func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.saveTasks(task)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteTask deletes a task
|
// DeleteTask deletes a task
|
||||||
func (s *AgentJobService) DeleteTask(id string) error {
|
func (s *AgentJobService) DeleteTask(id string) error {
|
||||||
if _, ok := s.tasks.Get(id); !ok {
|
if !s.tasks.Exists(id) {
|
||||||
return fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
return fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unschedule cron
|
// Unschedule cron
|
||||||
s.UnscheduleCronTask(id)
|
s.UnscheduleCronTask(id)
|
||||||
|
|
||||||
// Delete removes from the in-memory map, deletes from the persister, and
|
// Remove from memory
|
||||||
// broadcasts the removal to peer replicas.
|
s.tasks.Delete(id)
|
||||||
if err := s.tasks.Delete(context.Background(), id); err != nil {
|
|
||||||
xlog.Warn("Failed to delete task from store", "error", err, "task_id", id)
|
if err := s.persister.DeleteTask(id); err != nil {
|
||||||
|
xlog.Warn("Failed to delete task from persister", "error", err, "task_id", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -393,8 +359,8 @@ func (s *AgentJobService) DeleteTask(id string) error {
|
|||||||
|
|
||||||
// GetTask retrieves a task by ID
|
// GetTask retrieves a task by ID
|
||||||
func (s *AgentJobService) GetTask(id string) (*schema.Task, error) {
|
func (s *AgentJobService) GetTask(id string) (*schema.Task, error) {
|
||||||
task, ok := s.tasks.Get(id)
|
task := s.tasks.Get(id)
|
||||||
if !ok {
|
if task.ID == "" {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
return nil, fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
||||||
}
|
}
|
||||||
return &task, nil
|
return &task, nil
|
||||||
@@ -402,7 +368,7 @@ func (s *AgentJobService) GetTask(id string) (*schema.Task, error) {
|
|||||||
|
|
||||||
// ListTasks returns all tasks, sorted by creation date (newest first)
|
// ListTasks returns all tasks, sorted by creation date (newest first)
|
||||||
func (s *AgentJobService) ListTasks() []schema.Task {
|
func (s *AgentJobService) ListTasks() []schema.Task {
|
||||||
tasks := s.tasks.List()
|
tasks := s.tasks.Values()
|
||||||
// Sort by CreatedAt descending (newest first), then by Name for stability
|
// Sort by CreatedAt descending (newest first), then by Name for stability
|
||||||
slices.SortFunc(tasks, func(a, b schema.Task) int {
|
slices.SortFunc(tasks, func(a, b schema.Task) int {
|
||||||
if a.CreatedAt.Equal(b.CreatedAt) {
|
if a.CreatedAt.Equal(b.CreatedAt) {
|
||||||
@@ -431,8 +397,8 @@ func (s *AgentJobService) buildPrompt(templateStr string, params map[string]stri
|
|||||||
// ExecuteJob creates and queues a job for execution
|
// ExecuteJob creates and queues a job for execution
|
||||||
// multimedia can be nil for backward compatibility
|
// multimedia can be nil for backward compatibility
|
||||||
func (s *AgentJobService) ExecuteJob(taskID string, params map[string]string, triggeredBy string, multimedia *schema.MultimediaAttachment) (string, error) {
|
func (s *AgentJobService) ExecuteJob(taskID string, params map[string]string, triggeredBy string, multimedia *schema.MultimediaAttachment) (string, error) {
|
||||||
task, ok := s.tasks.Get(taskID)
|
task := s.tasks.Get(taskID)
|
||||||
if !ok {
|
if task.ID == "" {
|
||||||
return "", fmt.Errorf("%w: %s", ErrTaskNotFound, taskID)
|
return "", fmt.Errorf("%w: %s", ErrTaskNotFound, taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1485,12 +1451,6 @@ func (s *AgentJobService) Stop() error {
|
|||||||
if s.cronScheduler != nil {
|
if s.cronScheduler != nil {
|
||||||
s.cronScheduler.Stop()
|
s.cronScheduler.Stop()
|
||||||
}
|
}
|
||||||
// Release the tasks SyncedMap subscription / background workers.
|
|
||||||
if s.tasks != nil {
|
|
||||||
if err := s.tasks.Close(); err != nil {
|
|
||||||
xlog.Warn("Error closing tasks sync map", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
xlog.Info("AgentJobService stopped")
|
xlog.Info("AgentJobService stopped")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,38 +14,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// fileJobPersister persists tasks and jobs to JSON files.
|
// fileJobPersister persists tasks and jobs to JSON files.
|
||||||
//
|
// It holds references to the service's syncmaps and serializes the entire
|
||||||
// Jobs serialize the service's in-memory jobs syncmap on each save (bulk write).
|
// map contents on each save (bulk write). Reads at runtime return nil
|
||||||
// Tasks are kept in this persister's own taskSet map instead: the tasks SyncedMap
|
// (the in-memory map is the authoritative source); LoadTasks/LoadJobs
|
||||||
// calls SaveTask/DeleteTask while holding its internal lock (write-through), so
|
// are used only at startup to bootstrap the syncmaps.
|
||||||
// reading back the SyncedMap here would re-enter that lock and deadlock. The
|
|
||||||
// self-contained taskSet, seeded by LoadTasks, lets a per-task write rewrite the
|
|
||||||
// whole bulk file without touching the SyncedMap.
|
|
||||||
//
|
|
||||||
// Runtime reads (GetJob/ListJobs) return nil (the in-memory state is the
|
|
||||||
// authoritative source); LoadTasks/LoadJobs bootstrap state at startup.
|
|
||||||
type fileJobPersister struct {
|
type fileJobPersister struct {
|
||||||
|
tasks *xsync.SyncedMap[string, schema.Task]
|
||||||
jobs *xsync.SyncedMap[string, schema.Job]
|
jobs *xsync.SyncedMap[string, schema.Job]
|
||||||
tasksFile string
|
tasksFile string
|
||||||
jobsFile string
|
jobsFile string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// taskSet is the persister's own view of all tasks, seeded by LoadTasks and
|
|
||||||
// updated by SaveTask/DeleteTask. The bulk JSON file is rewritten from it.
|
|
||||||
taskSet map[string]schema.Task
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *fileJobPersister) SaveTask(_ string, task schema.Task) error {
|
func (p *fileJobPersister) SaveTask(_ string, _ schema.Task) error {
|
||||||
p.mu.Lock()
|
return p.saveTasksToFile()
|
||||||
defer p.mu.Unlock()
|
|
||||||
p.taskSet[task.ID] = task
|
|
||||||
return p.writeTasksLocked()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *fileJobPersister) DeleteTask(taskID string) error {
|
func (p *fileJobPersister) DeleteTask(_ string) error {
|
||||||
p.mu.Lock()
|
return p.saveTasksToFile()
|
||||||
defer p.mu.Unlock()
|
|
||||||
delete(p.taskSet, taskID)
|
|
||||||
return p.writeTasksLocked()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *fileJobPersister) SaveJob(_ string, _ schema.Job) error {
|
func (p *fileJobPersister) SaveJob(_ string, _ schema.Job) error {
|
||||||
@@ -57,9 +43,7 @@ func (p *fileJobPersister) DeleteJob(_ string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *fileJobPersister) FlushTasks() error {
|
func (p *fileJobPersister) FlushTasks() error {
|
||||||
p.mu.Lock()
|
return p.saveTasksToFile()
|
||||||
defer p.mu.Unlock()
|
|
||||||
return p.writeTasksLocked()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *fileJobPersister) FlushJobs() error {
|
func (p *fileJobPersister) FlushJobs() error {
|
||||||
@@ -99,12 +83,6 @@ func (p *fileJobPersister) LoadTasks(_ string) ([]schema.Task, error) {
|
|||||||
return nil, fmt.Errorf("failed to parse tasks file: %w", err)
|
return nil, fmt.Errorf("failed to parse tasks file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Seed the in-memory set so subsequent per-task SaveTask/DeleteTask merge into
|
|
||||||
// (rather than overwrite) the persisted tasks when the bulk file is rewritten.
|
|
||||||
for _, t := range tf.Tasks {
|
|
||||||
p.taskSet[t.ID] = t
|
|
||||||
}
|
|
||||||
|
|
||||||
xlog.Info("Loaded tasks from file", "count", len(tf.Tasks))
|
xlog.Info("Loaded tasks from file", "count", len(tf.Tasks))
|
||||||
return tf.Tasks, nil
|
return tf.Tasks, nil
|
||||||
}
|
}
|
||||||
@@ -140,19 +118,18 @@ func (p *fileJobPersister) CleanupOldJobs(_ time.Duration) (int64, error) {
|
|||||||
return 0, nil // cleanup handled via in-memory filtering
|
return 0, nil // cleanup handled via in-memory filtering
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeTasksLocked serializes the persister's task set to the JSON file. Callers
|
// saveTasksToFile serializes the entire tasks map to the JSON file.
|
||||||
// must hold p.mu.
|
func (p *fileJobPersister) saveTasksToFile() error {
|
||||||
func (p *fileJobPersister) writeTasksLocked() error {
|
|
||||||
if p.tasksFile == "" {
|
if p.tasksFile == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks := make([]schema.Task, 0, len(p.taskSet))
|
p.mu.Lock()
|
||||||
for _, t := range p.taskSet {
|
defer p.mu.Unlock()
|
||||||
tasks = append(tasks, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
tf := schema.TasksFile{Tasks: tasks}
|
tf := schema.TasksFile{
|
||||||
|
Tasks: p.tasks.Values(),
|
||||||
|
}
|
||||||
|
|
||||||
data, err := json.MarshalIndent(tf, "", " ")
|
data, err := json.MarshalIndent(tf, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -20,26 +20,28 @@ var _ = Describe("JobPersister", func() {
|
|||||||
Context("fileJobPersister", func() {
|
Context("fileJobPersister", func() {
|
||||||
var (
|
var (
|
||||||
p *fileJobPersister
|
p *fileJobPersister
|
||||||
|
tasks *xsync.SyncedMap[string, schema.Task]
|
||||||
jobsMap *xsync.SyncedMap[string, schema.Job]
|
jobsMap *xsync.SyncedMap[string, schema.Job]
|
||||||
tmpDir string
|
tmpDir string
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
tmpDir = GinkgoT().TempDir()
|
tmpDir = GinkgoT().TempDir()
|
||||||
|
tasks = xsync.NewSyncedMap[string, schema.Task]()
|
||||||
jobsMap = xsync.NewSyncedMap[string, schema.Job]()
|
jobsMap = xsync.NewSyncedMap[string, schema.Job]()
|
||||||
p = &fileJobPersister{
|
p = &fileJobPersister{
|
||||||
|
tasks: tasks,
|
||||||
jobs: jobsMap,
|
jobs: jobsMap,
|
||||||
tasksFile: filepath.Join(tmpDir, "tasks.json"),
|
tasksFile: filepath.Join(tmpDir, "tasks.json"),
|
||||||
jobsFile: filepath.Join(tmpDir, "jobs.json"),
|
jobsFile: filepath.Join(tmpDir, "jobs.json"),
|
||||||
// taskSet is the persister's own task view (decoupled from the tasks
|
|
||||||
// SyncedMap to avoid re-entering its lock during write-through).
|
|
||||||
taskSet: make(map[string]schema.Task),
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("SaveTask writes all tasks to file", func() {
|
It("SaveTask writes all tasks to file", func() {
|
||||||
Expect(p.SaveTask("", schema.Task{ID: "t1", Name: "Task One", Model: "m", Prompt: "p"})).To(Succeed())
|
tasks.Set("t1", schema.Task{ID: "t1", Name: "Task One", Model: "m", Prompt: "p"})
|
||||||
Expect(p.SaveTask("", schema.Task{ID: "t2", Name: "Task Two", Model: "m", Prompt: "p"})).To(Succeed())
|
tasks.Set("t2", schema.Task{ID: "t2", Name: "Task Two", Model: "m", Prompt: "p"})
|
||||||
|
|
||||||
|
Expect(p.SaveTask("", schema.Task{})).To(Succeed())
|
||||||
|
|
||||||
// Verify file contents
|
// Verify file contents
|
||||||
data, err := os.ReadFile(p.tasksFile)
|
data, err := os.ReadFile(p.tasksFile)
|
||||||
@@ -50,9 +52,11 @@ var _ = Describe("JobPersister", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("DeleteTask writes updated tasks to file", func() {
|
It("DeleteTask writes updated tasks to file", func() {
|
||||||
Expect(p.SaveTask("", schema.Task{ID: "t1", Name: "Keep"})).To(Succeed())
|
tasks.Set("t1", schema.Task{ID: "t1", Name: "Keep"})
|
||||||
Expect(p.SaveTask("", schema.Task{ID: "t2", Name: "Delete"})).To(Succeed())
|
tasks.Set("t2", schema.Task{ID: "t2", Name: "Delete"})
|
||||||
|
|
||||||
|
// Simulate deletion from memory (caller does this before calling persister)
|
||||||
|
tasks.Delete("t2")
|
||||||
Expect(p.DeleteTask("t2")).To(Succeed())
|
Expect(p.DeleteTask("t2")).To(Succeed())
|
||||||
|
|
||||||
data, err := os.ReadFile(p.tasksFile)
|
data, err := os.ReadFile(p.tasksFile)
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
package agentpool
|
|
||||||
|
|
||||||
// White-box tests (package agentpool) so a spec can build two AgentJobService
|
|
||||||
// instances sharing one in-memory bus and assert that agent *tasks* converge
|
|
||||||
// across replicas - the bug this migration fixes (ListTasks used to read
|
|
||||||
// in-memory only, so a task created on replica A was invisible on replica B).
|
|
||||||
// Jobs are deliberately untouched here: they already converge via the dispatcher
|
|
||||||
// + DB read-through.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
"github.com/mudler/LocalAI/core/services/testutil"
|
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newTaskSyncService builds an AgentJobService wired to the given bus and a
|
|
||||||
// throwaway data dir (so the file persister has somewhere to write). Model/config
|
|
||||||
// loaders are nil because the task sync paths under test never touch them.
|
|
||||||
func newTaskSyncService(bus messaging.MessagingClient) *AgentJobService {
|
|
||||||
tmpDir := GinkgoT().TempDir()
|
|
||||||
sysState := &system.SystemState{}
|
|
||||||
sysState.Model.ModelsPath = tmpDir
|
|
||||||
appConfig := config.NewApplicationConfig(
|
|
||||||
config.WithDynamicConfigDir(tmpDir),
|
|
||||||
config.WithContext(context.Background()),
|
|
||||||
)
|
|
||||||
appConfig.SystemState = sysState
|
|
||||||
|
|
||||||
svc := NewAgentJobServiceWithPaths(appConfig, nil, nil, nil,
|
|
||||||
// Distinct per-replica files so the file persister write-through never
|
|
||||||
// crosses replicas: convergence here must be proven via the bus alone.
|
|
||||||
tmpDir+"/tasks.json", tmpDir+"/jobs.json")
|
|
||||||
svc.SetTaskSyncNATS(bus)
|
|
||||||
return svc
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("AgentJobService task cross-replica sync", func() {
|
|
||||||
Describe("two replicas sharing one bus", func() {
|
|
||||||
var (
|
|
||||||
bus *testutil.FakeBus
|
|
||||||
a, b *AgentJobService
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
// One shared bus, two replicas: exactly the distributed topology where a
|
|
||||||
// round-robin request may land on a replica that did not originate the
|
|
||||||
// change.
|
|
||||||
bus = testutil.NewFakeBus()
|
|
||||||
a = newTaskSyncService(bus)
|
|
||||||
b = newTaskSyncService(bus)
|
|
||||||
// Start hydrates (empty here) and subscribes both replicas to deltas.
|
|
||||||
Expect(a.Start(context.Background())).To(Succeed())
|
|
||||||
Expect(b.Start(context.Background())).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(a.Stop()).To(Succeed())
|
|
||||||
Expect(b.Stop()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("makes a task created on A visible via B's GetTask and ListTasks", func() {
|
|
||||||
id, err := a.CreateTask(schema.Task{Name: "Shared", Model: "m", Prompt: "p"})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
got, err := b.GetTask(id)
|
|
||||||
Expect(err).NotTo(HaveOccurred(), "B must see a task A just created")
|
|
||||||
Expect(got.Name).To(Equal("Shared"))
|
|
||||||
|
|
||||||
listed := b.ListTasks()
|
|
||||||
Expect(listed).To(HaveLen(1))
|
|
||||||
Expect(listed[0].ID).To(Equal(id))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("propagates a task update from A to B", func() {
|
|
||||||
id, err := a.CreateTask(schema.Task{Name: "Before", Model: "m", Prompt: "p"})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
Expect(a.UpdateTask(id, schema.Task{Name: "After", Model: "m", Prompt: "p"})).To(Succeed())
|
|
||||||
|
|
||||||
got, err := b.GetTask(id)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(got.Name).To(Equal("After"), "an update on A must be visible on B")
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes a task from B when it is deleted on A", func() {
|
|
||||||
id, err := a.CreateTask(schema.Task{Name: "Doomed", Model: "m", Prompt: "p"})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
_, err = b.GetTask(id)
|
|
||||||
Expect(err).NotTo(HaveOccurred(), "precondition: B must have the task before the delete")
|
|
||||||
|
|
||||||
Expect(a.DeleteTask(id)).To(Succeed())
|
|
||||||
|
|
||||||
_, err = b.GetTask(id)
|
|
||||||
Expect(err).To(HaveOccurred(), "a delete on A must remove the task from B")
|
|
||||||
Expect(b.ListTasks()).To(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not re-broadcast a delta it received (echo-loop guard)", func() {
|
|
||||||
subject := messaging.SubjectSyncStateDelta("agent.tasks")
|
|
||||||
|
|
||||||
_, err := a.CreateTask(schema.Task{Name: "Once", Model: "m", Prompt: "p"})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
// Exactly one publish: A's create. B applies it without re-publishing,
|
|
||||||
// otherwise this would be 2+ and a real bus would storm.
|
|
||||||
Expect(bus.PublishCount(subject)).To(Equal(1))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("ListTasks ordering and scoping", func() {
|
|
||||||
var svc *AgentJobService
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
svc = newTaskSyncService(testutil.NewFakeBus())
|
|
||||||
Expect(svc.Start(context.Background())).To(Succeed())
|
|
||||||
})
|
|
||||||
AfterEach(func() { Expect(svc.Stop()).To(Succeed()) })
|
|
||||||
|
|
||||||
It("sorts newest-first, breaking ties by name", func() {
|
|
||||||
// CreateTask stamps CreatedAt with time.Now(); space them out so ordering
|
|
||||||
// is deterministic rather than relying on the sub-millisecond gap.
|
|
||||||
oldID, err := svc.CreateTask(schema.Task{Name: "Old", Model: "m", Prompt: "p"})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
time.Sleep(5 * time.Millisecond)
|
|
||||||
newID, err := svc.CreateTask(schema.Task{Name: "New", Model: "m", Prompt: "p"})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
listed := svc.ListTasks()
|
|
||||||
Expect(listed).To(HaveLen(2))
|
|
||||||
Expect(listed[0].ID).To(Equal(newID), "newest first")
|
|
||||||
Expect(listed[1].ID).To(Equal(oldID))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("compile-time adapter contract", func() {
|
|
||||||
It("satisfies syncstate.Store for tasks", func() {
|
|
||||||
// Mirrors the var assertion in task_syncstore.go; keeps the type
|
|
||||||
// referenced from a spec so drift surfaces here too.
|
|
||||||
var _ syncstate.Store[string, schema.Task] = (*taskStoreAdapter)(nil)
|
|
||||||
Expect(&taskStoreAdapter{}).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
package agentpool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
)
|
|
||||||
|
|
||||||
// taskStoreAdapter bridges the existing JobPersister (file- or DB-backed) to the
|
|
||||||
// generic syncstate.Store the tasks SyncedMap consumes. Only tasks are migrated:
|
|
||||||
// jobs already converge across replicas via the dispatcher (NATS) plus the DB
|
|
||||||
// read-through in ListJobs/GetJob, whereas ListTasks read in-memory only and so
|
|
||||||
// went stale on replicas that did not originate the change.
|
|
||||||
//
|
|
||||||
// The adapter reads svc.persister and svc.userID live (rather than capturing
|
|
||||||
// them) because both are configured by setters - SetDistributedJobStore swaps the
|
|
||||||
// file persister for the DB one, SetUserID scopes per-user queries - AFTER the
|
|
||||||
// service, and thus this adapter, is constructed. Reading them at call time means
|
|
||||||
// the SyncedMap never has to be rebuilt when the persister is swapped.
|
|
||||||
//
|
|
||||||
// The SyncedMap value type is schema.Task: the exact shape ListTasks returns, so
|
|
||||||
// reads need no conversion and REST responses are provably unchanged.
|
|
||||||
type taskStoreAdapter struct {
|
|
||||||
svc *AgentJobService
|
|
||||||
}
|
|
||||||
|
|
||||||
// compile-time assertion that the adapter satisfies the component's Store.
|
|
||||||
var _ syncstate.Store[string, schema.Task] = (*taskStoreAdapter)(nil)
|
|
||||||
|
|
||||||
// List hydrates the map from durable storage on Start/reconnect: the file's task
|
|
||||||
// list (standalone) or every task row (DB / distributed).
|
|
||||||
func (a *taskStoreAdapter) List(_ context.Context) ([]schema.Task, error) {
|
|
||||||
return a.svc.persister.LoadTasks(a.svc.userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsert write-through persists a single task created/updated locally; the
|
|
||||||
// SyncedMap then broadcasts the delta to peers.
|
|
||||||
func (a *taskStoreAdapter) Upsert(_ context.Context, task schema.Task) error {
|
|
||||||
return a.svc.persister.SaveTask(a.svc.userID, task)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete write-through removes a task locally; the SyncedMap then broadcasts the
|
|
||||||
// removal to peers.
|
|
||||||
func (a *taskStoreAdapter) Delete(_ context.Context, id string) error {
|
|
||||||
return a.svc.persister.DeleteTask(id)
|
|
||||||
}
|
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/mudler/LocalAGI/webui/collections"
|
"github.com/mudler/LocalAGI/webui/collections"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/services/jobs"
|
"github.com/mudler/LocalAI/core/services/jobs"
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/LocalAI/core/templates"
|
"github.com/mudler/LocalAI/core/templates"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
@@ -29,9 +28,6 @@ type UserServicesManager struct {
|
|||||||
// Shared distributed backends (set once, inherited by per-user job services)
|
// Shared distributed backends (set once, inherited by per-user job services)
|
||||||
jobDispatcher DistributedDispatcher
|
jobDispatcher DistributedDispatcher
|
||||||
jobDBStore *jobs.JobStore
|
jobDBStore *jobs.JobStore
|
||||||
// jobNats keeps per-user agent tasks consistent across replicas (nil in
|
|
||||||
// standalone). Inherited by each per-user AgentJobService.
|
|
||||||
jobNats messaging.MessagingClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserServicesManager creates a new UserServicesManager.
|
// NewUserServicesManager creates a new UserServicesManager.
|
||||||
@@ -166,10 +162,6 @@ func (m *UserServicesManager) GetJobs(userID string) (*AgentJobService, error) {
|
|||||||
if m.jobDispatcher != nil {
|
if m.jobDispatcher != nil {
|
||||||
svc.SetDistributedBackends(m.jobDispatcher)
|
svc.SetDistributedBackends(m.jobDispatcher)
|
||||||
}
|
}
|
||||||
// Inherit the NATS client so per-user tasks broadcast across replicas. Must be
|
|
||||||
// set before the hydrate below (LoadFromDB / LoadTasksFromFile) so the tasks
|
|
||||||
// SyncedMap is rebuilt with the client while it is still empty.
|
|
||||||
svc.SetTaskSyncNATS(m.jobNats)
|
|
||||||
if m.jobDBStore != nil {
|
if m.jobDBStore != nil {
|
||||||
svc.SetDistributedJobStore(m.jobDBStore)
|
svc.SetDistributedJobStore(m.jobDBStore)
|
||||||
// Load tasks/jobs from DB immediately (per-user services skip Start())
|
// Load tasks/jobs from DB immediately (per-user services skip Start())
|
||||||
@@ -197,12 +189,6 @@ func (m *UserServicesManager) SetJobDBStore(s *jobs.JobStore) {
|
|||||||
m.jobDBStore = s
|
m.jobDBStore = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetJobSyncNATS sets the NATS client used to keep per-user agent tasks consistent
|
|
||||||
// across replicas.
|
|
||||||
func (m *UserServicesManager) SetJobSyncNATS(nats messaging.MessagingClient) {
|
|
||||||
m.jobNats = nats
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListAllUserIDs returns all user IDs that have scoped data directories.
|
// ListAllUserIDs returns all user IDs that have scoped data directories.
|
||||||
func (m *UserServicesManager) ListAllUserIDs() ([]string, error) {
|
func (m *UserServicesManager) ListAllUserIDs() ([]string, error) {
|
||||||
return m.storage.ListUserDirs()
|
return m.storage.ListUserDirs()
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// FineTuneJobRecord tracks fine-tune jobs in PostgreSQL.
|
// FineTuneJobRecord tracks fine-tune jobs in PostgreSQL.
|
||||||
@@ -81,34 +80,6 @@ func (s *FineTuneStore) List(userID string) ([]FineTuneJobRecord, error) {
|
|||||||
return jobs, q.Find(&jobs).Error
|
return jobs, q.Find(&jobs).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAll returns every fine-tune job across all users. The SyncedMap that backs
|
|
||||||
// FineTuneService is a single global map (the REST API filters by user at read
|
|
||||||
// time), so hydrate needs the full set rather than the per-user List above.
|
|
||||||
func (s *FineTuneStore) ListAll() ([]FineTuneJobRecord, error) {
|
|
||||||
var jobs []FineTuneJobRecord
|
|
||||||
return jobs, s.db.Order("created_at DESC").Find(&jobs).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsert idempotently inserts or fully replaces a job row by primary key. The
|
|
||||||
// SyncedMap write-through path issues a single Set per mutation regardless of
|
|
||||||
// whether the job already exists, so it needs one create-or-update primitive
|
|
||||||
// (Create alone fails on a duplicate key, UpdateStatus alone misses new rows and
|
|
||||||
// only touches a few columns).
|
|
||||||
func (s *FineTuneStore) Upsert(job *FineTuneJobRecord) error {
|
|
||||||
if job.ID == "" {
|
|
||||||
job.ID = uuid.New().String()
|
|
||||||
}
|
|
||||||
now := time.Now()
|
|
||||||
if job.CreatedAt.IsZero() {
|
|
||||||
job.CreatedAt = now
|
|
||||||
}
|
|
||||||
job.UpdatedAt = now
|
|
||||||
return s.db.Clauses(clause.OnConflict{
|
|
||||||
Columns: []clause.Column{{Name: "id"}},
|
|
||||||
UpdateAll: true,
|
|
||||||
}).Create(job).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateStatus updates the status and message of a fine-tune job.
|
// UpdateStatus updates the status and message of a fine-tune job.
|
||||||
func (s *FineTuneStore) UpdateStatus(id, status, message string) error {
|
func (s *FineTuneStore) UpdateStatus(id, status, message string) error {
|
||||||
return s.db.Model(&FineTuneJobRecord{}).Where("id = ?", id).Updates(map[string]any{
|
return s.db.Model(&FineTuneJobRecord{}).Where("id = ?", id).Updates(map[string]any{
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package distributed_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDistributed(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Distributed Suite")
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
package distributed_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("FineTuneStore", func() {
|
|
||||||
var store *distributed.FineTuneStore
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
db := testutil.SetupTestDB()
|
|
||||||
var err error
|
|
||||||
store, err = distributed.NewFineTuneStore(db)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("ListAll", func() {
|
|
||||||
It("returns jobs across all users (unlike per-user List)", func() {
|
|
||||||
Expect(store.Create(&distributed.FineTuneJobRecord{ID: "j1", UserID: "u1", Status: "queued"})).To(Succeed())
|
|
||||||
Expect(store.Create(&distributed.FineTuneJobRecord{ID: "j2", UserID: "u2", Status: "queued"})).To(Succeed())
|
|
||||||
|
|
||||||
all, err := store.ListAll()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(all).To(HaveLen(2))
|
|
||||||
|
|
||||||
perUser, err := store.List("u1")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(perUser).To(HaveLen(1), "List stays per-user")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("Upsert", func() {
|
|
||||||
It("inserts a new row", func() {
|
|
||||||
Expect(store.Upsert(&distributed.FineTuneJobRecord{ID: "up-1", UserID: "u1", Status: "queued"})).To(Succeed())
|
|
||||||
|
|
||||||
got, err := store.Get("up-1")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(got.Status).To(Equal("queued"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("idempotently updates an existing row on a repeated key", func() {
|
|
||||||
Expect(store.Upsert(&distributed.FineTuneJobRecord{ID: "up-2", UserID: "u1", Status: "queued"})).To(Succeed())
|
|
||||||
// Second Upsert with the same primary key must update, not error on a
|
|
||||||
// duplicate-key violation (this is the SyncedMap write-through contract).
|
|
||||||
Expect(store.Upsert(&distributed.FineTuneJobRecord{ID: "up-2", UserID: "u1", Status: "completed", Message: "done"})).To(Succeed())
|
|
||||||
|
|
||||||
got, err := store.Get("up-2")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(got.Status).To(Equal("completed"))
|
|
||||||
Expect(got.Message).To(Equal("done"))
|
|
||||||
|
|
||||||
all, err := store.ListAll()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(all).To(HaveLen(1), "upsert must not create a duplicate")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
type Stores struct {
|
type Stores struct {
|
||||||
Gallery *GalleryStore
|
Gallery *GalleryStore
|
||||||
FineTune *FineTuneStore
|
FineTune *FineTuneStore
|
||||||
Quant *QuantStore
|
|
||||||
Skills *SkillStore
|
Skills *SkillStore
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,21 +26,15 @@ func InitStores(db *gorm.DB) (*Stores, error) {
|
|||||||
return nil, fmt.Errorf("fine-tune store: %w", err)
|
return nil, fmt.Errorf("fine-tune store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
quant, err := NewQuantStore(db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("quantization store: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
skills, err := NewSkillStore(db)
|
skills, err := NewSkillStore(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("skills store: %w", err)
|
return nil, fmt.Errorf("skills store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
xlog.Info("Distributed stores initialized (Gallery, FineTune, Quant, Skills)")
|
xlog.Info("Distributed stores initialized (Gallery, FineTune, Skills)")
|
||||||
return &Stores{
|
return &Stores{
|
||||||
Gallery: gallery,
|
Gallery: gallery,
|
||||||
FineTune: ft,
|
FineTune: ft,
|
||||||
Quant: quant,
|
|
||||||
Skills: skills,
|
Skills: skills,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,105 +0,0 @@
|
|||||||
package distributed
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/clause"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QuantJobRecord tracks quantization jobs in PostgreSQL. The columns mirror the
|
|
||||||
// API shape (schema.QuantizationJob); the structured Config and ExtraOptions are
|
|
||||||
// serialized into JSON text columns so a record fully reconstructs the job.
|
|
||||||
type QuantJobRecord struct {
|
|
||||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
|
||||||
UserID string `gorm:"index;size:36" json:"user_id,omitempty"`
|
|
||||||
Model string `gorm:"size:255" json:"model"`
|
|
||||||
Backend string `gorm:"size:64" json:"backend"`
|
|
||||||
ModelID string `gorm:"size:255" json:"model_id,omitempty"`
|
|
||||||
QuantizationType string `gorm:"size:32" json:"quantization_type"`
|
|
||||||
Status string `gorm:"index;size:32;default:queued" json:"status"` // queued, downloading, converting, quantizing, completed, failed, stopped
|
|
||||||
Message string `gorm:"type:text" json:"message,omitempty"`
|
|
||||||
OutputDir string `gorm:"size:512" json:"output_dir,omitempty"`
|
|
||||||
OutputFile string `gorm:"size:512" json:"output_file,omitempty"`
|
|
||||||
ConfigJSON string `gorm:"column:config;type:text" json:"-"`
|
|
||||||
ExtraOptsJSON string `gorm:"column:extra_options;type:text" json:"-"`
|
|
||||||
ImportStatus string `gorm:"size:32" json:"import_status,omitempty"`
|
|
||||||
ImportMessage string `gorm:"type:text" json:"import_message,omitempty"`
|
|
||||||
ImportModelName string `gorm:"size:255" json:"import_model_name,omitempty"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (QuantJobRecord) TableName() string { return "quantization_jobs" }
|
|
||||||
|
|
||||||
// QuantStore manages quantization job state in PostgreSQL.
|
|
||||||
type QuantStore struct {
|
|
||||||
db *gorm.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQuantStore creates a new QuantStore and auto-migrates.
|
|
||||||
// Uses a PostgreSQL advisory lock to prevent concurrent migration races
|
|
||||||
// when multiple instances (frontend + workers) start at the same time.
|
|
||||||
func NewQuantStore(db *gorm.DB) (*QuantStore, error) {
|
|
||||||
if err := advisorylock.WithLockCtx(context.Background(), db, advisorylock.KeySchemaMigrate, func() error {
|
|
||||||
return db.AutoMigrate(&QuantJobRecord{})
|
|
||||||
}); err != nil {
|
|
||||||
return nil, fmt.Errorf("migrating quantization_jobs: %w", err)
|
|
||||||
}
|
|
||||||
return &QuantStore{db: db}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create stores a new quantization job.
|
|
||||||
func (s *QuantStore) Create(job *QuantJobRecord) error {
|
|
||||||
if job.ID == "" {
|
|
||||||
job.ID = uuid.New().String()
|
|
||||||
}
|
|
||||||
job.CreatedAt = time.Now()
|
|
||||||
job.UpdatedAt = job.CreatedAt
|
|
||||||
return s.db.Create(job).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves a quantization job by ID.
|
|
||||||
func (s *QuantStore) Get(id string) (*QuantJobRecord, error) {
|
|
||||||
var job QuantJobRecord
|
|
||||||
if err := s.db.First(&job, "id = ?", id).Error; err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &job, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListAll returns every quantization job across all users. The SyncedMap that
|
|
||||||
// backs QuantizationService is a single global map (the REST API filters by user
|
|
||||||
// at read time), so hydrate needs the full set.
|
|
||||||
func (s *QuantStore) ListAll() ([]QuantJobRecord, error) {
|
|
||||||
var jobs []QuantJobRecord
|
|
||||||
return jobs, s.db.Order("created_at DESC").Find(&jobs).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsert idempotently inserts or fully replaces a job row by primary key. The
|
|
||||||
// SyncedMap write-through path issues a single Set per mutation regardless of
|
|
||||||
// whether the job already exists, so it needs one create-or-update primitive
|
|
||||||
// (Create alone fails on a duplicate key).
|
|
||||||
func (s *QuantStore) Upsert(job *QuantJobRecord) error {
|
|
||||||
if job.ID == "" {
|
|
||||||
job.ID = uuid.New().String()
|
|
||||||
}
|
|
||||||
now := time.Now()
|
|
||||||
if job.CreatedAt.IsZero() {
|
|
||||||
job.CreatedAt = now
|
|
||||||
}
|
|
||||||
job.UpdatedAt = now
|
|
||||||
return s.db.Clauses(clause.OnConflict{
|
|
||||||
Columns: []clause.Column{{Name: "id"}},
|
|
||||||
UpdateAll: true,
|
|
||||||
}).Create(job).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete removes a quantization job.
|
|
||||||
func (s *QuantStore) Delete(id string) error {
|
|
||||||
return s.db.Where("id = ?", id).Delete(&QuantJobRecord{}).Error
|
|
||||||
}
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
package distributed_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("QuantStore", func() {
|
|
||||||
var store *distributed.QuantStore
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
db := testutil.SetupTestDB()
|
|
||||||
var err error
|
|
||||||
store, err = distributed.NewQuantStore(db)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("ListAll", func() {
|
|
||||||
It("returns jobs across all users", func() {
|
|
||||||
Expect(store.Create(&distributed.QuantJobRecord{ID: "j1", UserID: "u1", Status: "queued"})).To(Succeed())
|
|
||||||
Expect(store.Create(&distributed.QuantJobRecord{ID: "j2", UserID: "u2", Status: "queued"})).To(Succeed())
|
|
||||||
|
|
||||||
all, err := store.ListAll()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(all).To(HaveLen(2))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("Upsert", func() {
|
|
||||||
It("inserts a new row", func() {
|
|
||||||
Expect(store.Upsert(&distributed.QuantJobRecord{ID: "up-1", UserID: "u1", Status: "queued"})).To(Succeed())
|
|
||||||
|
|
||||||
got, err := store.Get("up-1")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(got.Status).To(Equal("queued"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("idempotently updates an existing row on a repeated key", func() {
|
|
||||||
Expect(store.Upsert(&distributed.QuantJobRecord{ID: "up-2", UserID: "u1", Status: "queued"})).To(Succeed())
|
|
||||||
// Second Upsert with the same primary key must update, not error on a
|
|
||||||
// duplicate-key violation (this is the SyncedMap write-through contract).
|
|
||||||
Expect(store.Upsert(&distributed.QuantJobRecord{ID: "up-2", UserID: "u1", Status: "completed", Message: "done"})).To(Succeed())
|
|
||||||
|
|
||||||
got, err := store.Get("up-2")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(got.Status).To(Equal("completed"))
|
|
||||||
Expect(got.Message).To(Equal("done"))
|
|
||||||
|
|
||||||
all, err := store.ListAll()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(all).To(HaveLen(1), "upsert must not create a duplicate")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package finetune
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestFinetune(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Finetune Suite")
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
"github.com/mudler/LocalAI/core/services/distributed"
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
"github.com/mudler/LocalAI/core/services/messaging"
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
@@ -33,63 +32,44 @@ type FineTuneService struct {
|
|||||||
modelLoader *model.ModelLoader
|
modelLoader *model.ModelLoader
|
||||||
configLoader *config.ModelConfigLoader
|
configLoader *config.ModelConfigLoader
|
||||||
|
|
||||||
// mu serializes the read-modify-write of job values. The SyncedMap guards its
|
mu sync.Mutex
|
||||||
// own map structure, but a job is a pointer mutated in place (e.g. the export
|
jobs map[string]*schema.FineTuneJob
|
||||||
// goroutine), so the service still needs a lock to keep those field updates
|
|
||||||
// and the subsequent Set atomic with respect to readers.
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
// jobs is the cross-replica job store: an in-memory map kept consistent across
|
// Distributed mode (nil when not in distributed mode)
|
||||||
// replicas via NATS, optionally read-through to PostgreSQL in distributed mode.
|
natsClient messaging.Publisher
|
||||||
jobs *syncstate.SyncedMap[string, *schema.FineTuneJob]
|
fineTuneStore *distributed.FineTuneStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFineTuneService creates a new FineTuneService. In distributed mode pass the
|
// SetNATSClient sets the NATS client for distributed progress publishing.
|
||||||
// shared NATS client and PostgreSQL store so jobs stay consistent across
|
func (s *FineTuneService) SetNATSClient(nc messaging.Publisher) {
|
||||||
// replicas; pass nil for both in standalone mode, where the disk Loader hydrates
|
s.mu.Lock()
|
||||||
// the map and there is nothing to broadcast.
|
defer s.mu.Unlock()
|
||||||
|
s.natsClient = nc
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFineTuneStore sets the PostgreSQL fine-tune store for distributed persistence.
|
||||||
|
func (s *FineTuneService) SetFineTuneStore(store *distributed.FineTuneStore) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.fineTuneStore = store
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFineTuneService creates a new FineTuneService.
|
||||||
func NewFineTuneService(
|
func NewFineTuneService(
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
modelLoader *model.ModelLoader,
|
modelLoader *model.ModelLoader,
|
||||||
configLoader *config.ModelConfigLoader,
|
configLoader *config.ModelConfigLoader,
|
||||||
nats messaging.MessagingClient,
|
|
||||||
store *distributed.FineTuneStore,
|
|
||||||
) *FineTuneService {
|
) *FineTuneService {
|
||||||
s := &FineTuneService{
|
s := &FineTuneService{
|
||||||
appConfig: appConfig,
|
appConfig: appConfig,
|
||||||
modelLoader: modelLoader,
|
modelLoader: modelLoader,
|
||||||
configLoader: configLoader,
|
configLoader: configLoader,
|
||||||
|
jobs: make(map[string]*schema.FineTuneJob),
|
||||||
}
|
}
|
||||||
|
s.loadAllJobs()
|
||||||
// Only attach a Store interface when a concrete store exists, otherwise the
|
|
||||||
// SyncedMap would see a non-nil interface wrapping a nil pointer and try to
|
|
||||||
// hydrate/write through a nil DB.
|
|
||||||
var syncStore syncstate.Store[string, *schema.FineTuneJob]
|
|
||||||
if store != nil {
|
|
||||||
syncStore = &fineTuneStoreAdapter{store: store}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.jobs = syncstate.New(syncstate.Config[string, *schema.FineTuneJob]{
|
|
||||||
Name: "finetune.jobs",
|
|
||||||
Key: func(j *schema.FineTuneJob) string { return j.ID },
|
|
||||||
Nats: nats,
|
|
||||||
Store: syncStore,
|
|
||||||
Loader: s.loadJobsFromDisk, // ignored when Store is set (distributed mode)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Hydrate + subscribe. A hydrate failure must not take the server down: log
|
|
||||||
// and continue degraded (standalone), mirroring the OpCache wiring.
|
|
||||||
if err := s.jobs.Start(appConfig.Context); err != nil {
|
|
||||||
xlog.Warn("FineTune SyncedMap start failed; running degraded", "error", err)
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close releases the SyncedMap subscription and background workers.
|
|
||||||
func (s *FineTuneService) Close() error {
|
|
||||||
return s.jobs.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// fineTuneBaseDir returns the base directory for fine-tune job data.
|
// fineTuneBaseDir returns the base directory for fine-tune job data.
|
||||||
func (s *FineTuneService) fineTuneBaseDir() string {
|
func (s *FineTuneService) fineTuneBaseDir() string {
|
||||||
return filepath.Join(s.appConfig.DataPath, "fine-tune")
|
return filepath.Join(s.appConfig.DataPath, "fine-tune")
|
||||||
@@ -120,18 +100,15 @@ func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadJobsFromDisk scans the fine-tune directory for persisted jobs and returns
|
// loadAllJobs scans the fine-tune directory for persisted jobs and loads them.
|
||||||
// them. It is the SyncedMap Loader used in standalone mode (no DB); the returned
|
func (s *FineTuneService) loadAllJobs() {
|
||||||
// slice hydrates the map on Start.
|
|
||||||
func (s *FineTuneService) loadJobsFromDisk(_ context.Context) ([]*schema.FineTuneJob, error) {
|
|
||||||
baseDir := s.fineTuneBaseDir()
|
baseDir := s.fineTuneBaseDir()
|
||||||
entries, err := os.ReadDir(baseDir)
|
entries, err := os.ReadDir(baseDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Directory doesn't exist yet — that's fine, start empty.
|
// Directory doesn't exist yet — that's fine
|
||||||
return nil, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs []*schema.FineTuneJob
|
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if !entry.IsDir() {
|
if !entry.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -160,13 +137,12 @@ func (s *FineTuneService) loadJobsFromDisk(_ context.Context) ([]*schema.FineTun
|
|||||||
job.ExportMessage = "Server restarted while export was running"
|
job.ExportMessage = "Server restarted while export was running"
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs = append(jobs, &job)
|
s.jobs[job.ID] = &job
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(jobs) > 0 {
|
if len(s.jobs) > 0 {
|
||||||
xlog.Info("Loaded persisted fine-tune jobs", "count", len(jobs))
|
xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
|
||||||
}
|
}
|
||||||
return jobs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartJob starts a new fine-tuning job.
|
// StartJob starts a new fine-tuning job.
|
||||||
@@ -260,14 +236,28 @@ func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schem
|
|||||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
Config: &req,
|
Config: &req,
|
||||||
}
|
}
|
||||||
// Set write-through persists to PostgreSQL (distributed) and broadcasts to
|
s.jobs[jobID] = job
|
||||||
// peer replicas; the disk state.json is written separately for restart
|
|
||||||
// recovery / standalone hydrate.
|
|
||||||
if err := s.jobs.Set(ctx, job); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to persist job: %w", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
|
|
||||||
|
// Persist to PostgreSQL in distributed mode
|
||||||
|
if s.fineTuneStore != nil {
|
||||||
|
configJSON, _ := json.Marshal(req)
|
||||||
|
extraJSON, _ := json.Marshal(req.ExtraOptions)
|
||||||
|
s.fineTuneStore.Create(&distributed.FineTuneJobRecord{
|
||||||
|
ID: jobID,
|
||||||
|
UserID: userID,
|
||||||
|
Model: req.Model,
|
||||||
|
Backend: backendName,
|
||||||
|
ModelID: modelID,
|
||||||
|
TrainingType: req.TrainingType,
|
||||||
|
TrainingMethod: req.TrainingMethod,
|
||||||
|
Status: "queued",
|
||||||
|
OutputDir: outputDir,
|
||||||
|
ConfigJSON: string(configJSON),
|
||||||
|
ExtraOptsJSON: string(extraJSON),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return &schema.FineTuneJobResponse{
|
return &schema.FineTuneJobResponse{
|
||||||
ID: jobID,
|
ID: jobID,
|
||||||
Status: "queued",
|
Status: "queued",
|
||||||
@@ -280,7 +270,7 @@ func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, err
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
}
|
}
|
||||||
@@ -296,7 +286,7 @@ func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
|
|||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var result []*schema.FineTuneJob
|
var result []*schema.FineTuneJob
|
||||||
for _, job := range s.jobs.List() {
|
for _, job := range s.jobs {
|
||||||
if userID == "" || job.UserID == userID {
|
if userID == "" || job.UserID == userID {
|
||||||
result = append(result, job)
|
result = append(result, job)
|
||||||
}
|
}
|
||||||
@@ -312,7 +302,7 @@ func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
|
|||||||
// StopJob stops a running fine-tuning job.
|
// StopJob stops a running fine-tuning job.
|
||||||
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
|
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("job not found: %s", jobID)
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -333,10 +323,10 @@ func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, sav
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.Status = "stopped"
|
job.Status = "stopped"
|
||||||
job.Message = "Training stopped by user"
|
job.Message = "Training stopped by user"
|
||||||
if err := s.jobs.Set(ctx, job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist stopped job", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
|
if s.fineTuneStore != nil {
|
||||||
|
s.fineTuneStore.UpdateStatus(jobID, "stopped", "Training stopped by user")
|
||||||
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -345,7 +335,7 @@ func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, sav
|
|||||||
// DeleteJob removes a fine-tuning job and its associated data from disk.
|
// DeleteJob removes a fine-tuning job and its associated data from disk.
|
||||||
func (s *FineTuneService) DeleteJob(userID, jobID string) error {
|
func (s *FineTuneService) DeleteJob(userID, jobID string) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("job not found: %s", jobID)
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -370,10 +360,9 @@ func (s *FineTuneService) DeleteJob(userID, jobID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
exportModelName := job.ExportModelName
|
exportModelName := job.ExportModelName
|
||||||
// Delete write-through removes the DB row (distributed) and broadcasts the
|
delete(s.jobs, jobID)
|
||||||
// removal to peer replicas. DeleteJob has no ctx, so use Background.
|
if s.fineTuneStore != nil {
|
||||||
if err := s.jobs.Delete(context.Background(), jobID); err != nil {
|
s.fineTuneStore.Delete(jobID)
|
||||||
xlog.Warn("Failed to delete job from store", "job_id", jobID, "error", err)
|
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
@@ -409,7 +398,7 @@ func (s *FineTuneService) DeleteJob(userID, jobID string) error {
|
|||||||
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
|
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
|
||||||
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
|
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("job not found: %s", jobID)
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -438,7 +427,7 @@ func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID stri
|
|||||||
}, func(update *pb.FineTuneProgressUpdate) {
|
}, func(update *pb.FineTuneProgressUpdate) {
|
||||||
// Update job status and persist
|
// Update job status and persist
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if j, ok := s.jobs.Get(jobID); ok {
|
if j, ok := s.jobs[jobID]; ok {
|
||||||
// Don't let progress updates overwrite terminal states
|
// Don't let progress updates overwrite terminal states
|
||||||
isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
|
isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
|
||||||
if !isTerminal {
|
if !isTerminal {
|
||||||
@@ -447,10 +436,10 @@ func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID stri
|
|||||||
if update.Message != "" {
|
if update.Message != "" {
|
||||||
j.Message = update.Message
|
j.Message = update.Message
|
||||||
}
|
}
|
||||||
if err := s.jobs.Set(ctx, j); err != nil {
|
|
||||||
xlog.Warn("Failed to persist progress update", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(j)
|
s.saveJobState(j)
|
||||||
|
if s.fineTuneStore != nil {
|
||||||
|
s.fineTuneStore.UpdateStatus(jobID, j.Status, j.Message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
@@ -485,7 +474,7 @@ func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID stri
|
|||||||
// ListCheckpoints lists checkpoints for a job.
|
// ListCheckpoints lists checkpoints for a job.
|
||||||
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
|
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -531,7 +520,7 @@ func sanitizeModelName(s string) string {
|
|||||||
// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
|
// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
|
||||||
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
|
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return "", fmt.Errorf("job not found: %s", jobID)
|
return "", fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -583,9 +572,6 @@ func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string,
|
|||||||
job.ExportStatus = "exporting"
|
job.ExportStatus = "exporting"
|
||||||
job.ExportMessage = ""
|
job.ExportMessage = ""
|
||||||
job.ExportModelName = ""
|
job.ExportModelName = ""
|
||||||
if err := s.jobs.Set(ctx, job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist export start", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
@@ -676,30 +662,24 @@ func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string,
|
|||||||
|
|
||||||
xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
|
xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
|
||||||
|
|
||||||
// Runs after the HTTP request returns, so use Background rather than the
|
|
||||||
// (now likely cancelled) request ctx for the write-through.
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.ExportStatus = "completed"
|
job.ExportStatus = "completed"
|
||||||
job.ExportModelName = modelName
|
job.ExportModelName = modelName
|
||||||
job.ExportMessage = ""
|
job.ExportMessage = ""
|
||||||
if err := s.jobs.Set(context.Background(), job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist export completion", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
|
if s.fineTuneStore != nil {
|
||||||
|
s.fineTuneStore.UpdateExportStatus(jobID, "completed", "", modelName)
|
||||||
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return modelName, nil
|
return modelName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setExportMessage updates the export message and persists the job state. Called
|
// setExportMessage updates the export message and persists the job state.
|
||||||
// from the background export goroutine, so it uses Background for write-through.
|
|
||||||
func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string) {
|
func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.ExportMessage = msg
|
job.ExportMessage = msg
|
||||||
if err := s.jobs.Set(context.Background(), job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist export message", "job_id", job.ID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -707,7 +687,7 @@ func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string)
|
|||||||
// GetExportedModelPath returns the path to the exported model directory and its name.
|
// GetExportedModelPath returns the path to the exported model directory and its name.
|
||||||
func (s *FineTuneService) GetExportedModelPath(userID, jobID string) (string, string, error) {
|
func (s *FineTuneService) GetExportedModelPath(userID, jobID string) (string, string, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return "", "", fmt.Errorf("job not found: %s", jobID)
|
return "", "", fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -743,10 +723,10 @@ func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message strin
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.ExportStatus = "failed"
|
job.ExportStatus = "failed"
|
||||||
job.ExportMessage = message
|
job.ExportMessage = message
|
||||||
if err := s.jobs.Set(context.Background(), job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist export failure", "job_id", job.ID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
|
if s.fineTuneStore != nil {
|
||||||
|
s.fineTuneStore.UpdateExportStatus(job.ID, "failed", message, "")
|
||||||
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,185 +0,0 @@
|
|||||||
package finetune
|
|
||||||
|
|
||||||
// White-box tests (package finetune) so a spec can drive the service's internal
|
|
||||||
// SyncedMap the same way StartJob does (via jobs.Set) without standing up a
|
|
||||||
// training backend, then assert the cross-replica reads (GetJob/ListJobs) and
|
|
||||||
// the adapter conversions that keep REST responses byte-for-byte unchanged.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newTestService builds a standalone FineTuneService wired to the given bus. The
|
|
||||||
// model/config loaders are nil because the read/sync paths under test never touch
|
|
||||||
// them; the data dir is a throwaway temp dir so the disk Loader finds nothing.
|
|
||||||
func newTestService(bus *testutil.FakeBus) *FineTuneService {
|
|
||||||
appConfig := &config.ApplicationConfig{
|
|
||||||
Context: context.Background(),
|
|
||||||
DataPath: GinkgoT().TempDir(),
|
|
||||||
}
|
|
||||||
return NewFineTuneService(appConfig, nil, nil, bus, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("FineTuneService", func() {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
Describe("cross-replica job visibility", func() {
|
|
||||||
var (
|
|
||||||
bus *testutil.FakeBus
|
|
||||||
a, b *FineTuneService
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
// One shared bus, two replicas: exactly the distributed topology where
|
|
||||||
// a round-robin request may land on a replica that did not originate
|
|
||||||
// the change.
|
|
||||||
bus = testutil.NewFakeBus()
|
|
||||||
a = newTestService(bus)
|
|
||||||
b = newTestService(bus)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(a.Close()).To(Succeed())
|
|
||||||
Expect(b.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("makes a job created on A visible via B's GetJob and ListJobs", func() {
|
|
||||||
job := &schema.FineTuneJob{ID: "job-1", UserID: "user-1", Status: "queued", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
// StartJob persists via jobs.Set; drive that directly to avoid a backend.
|
|
||||||
Expect(a.jobs.Set(ctx, job)).To(Succeed())
|
|
||||||
|
|
||||||
got, err := b.GetJob("user-1", "job-1")
|
|
||||||
Expect(err).ToNot(HaveOccurred(), "B must see a job A just created")
|
|
||||||
Expect(got.Status).To(Equal("queued"))
|
|
||||||
|
|
||||||
listed := b.ListJobs("user-1")
|
|
||||||
Expect(listed).To(HaveLen(1))
|
|
||||||
Expect(listed[0].ID).To(Equal("job-1"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes a job from B when it is deleted on A", func() {
|
|
||||||
job := &schema.FineTuneJob{ID: "job-2", UserID: "user-1", Status: "completed", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
Expect(a.jobs.Set(ctx, job)).To(Succeed())
|
|
||||||
_, err := b.GetJob("user-1", "job-2")
|
|
||||||
Expect(err).ToNot(HaveOccurred(), "precondition: B must have the job before the delete")
|
|
||||||
|
|
||||||
Expect(a.jobs.Delete(ctx, "job-2")).To(Succeed())
|
|
||||||
|
|
||||||
_, err = b.GetJob("user-1", "job-2")
|
|
||||||
Expect(err).To(HaveOccurred(), "a delete on A must remove the job from B")
|
|
||||||
})
|
|
||||||
|
|
||||||
It("propagates a status update from A to B", func() {
|
|
||||||
job := &schema.FineTuneJob{ID: "job-3", UserID: "user-1", Status: "training", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
Expect(a.jobs.Set(ctx, job)).To(Succeed())
|
|
||||||
|
|
||||||
updated := &schema.FineTuneJob{ID: "job-3", UserID: "user-1", Status: "completed", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
Expect(a.jobs.Set(ctx, updated)).To(Succeed())
|
|
||||||
|
|
||||||
got, err := b.GetJob("user-1", "job-3")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(got.Status).To(Equal("completed"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("ListJobs", func() {
|
|
||||||
var svc *FineTuneService
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
svc = newTestService(testutil.NewFakeBus())
|
|
||||||
})
|
|
||||||
AfterEach(func() { Expect(svc.Close()).To(Succeed()) })
|
|
||||||
|
|
||||||
It("filters by user and sorts newest-first", func() {
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.FineTuneJob{ID: "old", UserID: "u1", CreatedAt: "2026-06-25T10:00:00Z"})).To(Succeed())
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.FineTuneJob{ID: "new", UserID: "u1", CreatedAt: "2026-06-27T10:00:00Z"})).To(Succeed())
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.FineTuneJob{ID: "other", UserID: "u2", CreatedAt: "2026-06-26T10:00:00Z"})).To(Succeed())
|
|
||||||
|
|
||||||
jobs := svc.ListJobs("u1")
|
|
||||||
Expect(jobs).To(HaveLen(2), "only u1's jobs")
|
|
||||||
Expect(jobs[0].ID).To(Equal("new"), "newest first")
|
|
||||||
Expect(jobs[1].ID).To(Equal("old"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns every user's jobs when the userID filter is empty", func() {
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.FineTuneJob{ID: "a", UserID: "u1", CreatedAt: "2026-06-25T10:00:00Z"})).To(Succeed())
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.FineTuneJob{ID: "b", UserID: "u2", CreatedAt: "2026-06-26T10:00:00Z"})).To(Succeed())
|
|
||||||
|
|
||||||
Expect(svc.ListJobs("")).To(HaveLen(2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects GetJob for a job owned by another user", func() {
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.FineTuneJob{ID: "x", UserID: "owner", CreatedAt: "2026-06-25T10:00:00Z"})).To(Succeed())
|
|
||||||
|
|
||||||
_, err := svc.GetJob("intruder", "x")
|
|
||||||
Expect(err).To(HaveOccurred(), "a different user must not read someone else's job")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("store adapter conversion", func() {
|
|
||||||
// The SyncedMap value type is *schema.FineTuneJob (the exact REST shape).
|
|
||||||
// These specs prove the DB adapter round-trips it losslessly, so hydrate
|
|
||||||
// and write-through in distributed mode keep responses unchanged.
|
|
||||||
It("round-trips a job through jobToRecord/recordToJob preserving the API shape", func() {
|
|
||||||
original := &schema.FineTuneJob{
|
|
||||||
ID: "rt-1",
|
|
||||||
UserID: "user-1",
|
|
||||||
Model: "base-model",
|
|
||||||
Backend: "trl",
|
|
||||||
ModelID: "trl-finetune-rt-1",
|
|
||||||
TrainingType: "lora",
|
|
||||||
TrainingMethod: "sft",
|
|
||||||
Status: "completed",
|
|
||||||
Message: "done",
|
|
||||||
OutputDir: "/data/fine-tune/rt-1",
|
|
||||||
ExtraOptions: map[string]string{"hf_token": "secret"},
|
|
||||||
CreatedAt: "2026-06-27T10:00:00Z",
|
|
||||||
ExportStatus: "completed",
|
|
||||||
ExportMessage: "",
|
|
||||||
ExportModelName: "base-model-ft-rt-1",
|
|
||||||
Config: &schema.FineTuneJobRequest{Model: "base-model", Backend: "trl", DatasetSource: "data.jsonl"},
|
|
||||||
}
|
|
||||||
|
|
||||||
rec := jobToRecord(original)
|
|
||||||
Expect(rec.ID).To(Equal("rt-1"))
|
|
||||||
Expect(rec.ConfigJSON).ToNot(BeEmpty(), "structured config must serialize into the JSON column")
|
|
||||||
Expect(rec.ExtraOptsJSON).ToNot(BeEmpty())
|
|
||||||
|
|
||||||
back := recordToJob(rec)
|
|
||||||
Expect(back.ID).To(Equal(original.ID))
|
|
||||||
Expect(back.UserID).To(Equal(original.UserID))
|
|
||||||
Expect(back.Model).To(Equal(original.Model))
|
|
||||||
Expect(back.Backend).To(Equal(original.Backend))
|
|
||||||
Expect(back.ModelID).To(Equal(original.ModelID))
|
|
||||||
Expect(back.TrainingType).To(Equal(original.TrainingType))
|
|
||||||
Expect(back.TrainingMethod).To(Equal(original.TrainingMethod))
|
|
||||||
Expect(back.Status).To(Equal(original.Status))
|
|
||||||
Expect(back.Message).To(Equal(original.Message))
|
|
||||||
Expect(back.OutputDir).To(Equal(original.OutputDir))
|
|
||||||
Expect(back.ExportStatus).To(Equal(original.ExportStatus))
|
|
||||||
Expect(back.ExportModelName).To(Equal(original.ExportModelName))
|
|
||||||
Expect(back.CreatedAt).To(Equal(original.CreatedAt))
|
|
||||||
Expect(back.ExtraOptions).To(Equal(original.ExtraOptions))
|
|
||||||
Expect(back.Config).ToNot(BeNil())
|
|
||||||
Expect(back.Config.DatasetSource).To(Equal("data.jsonl"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("compile-time adapter contract", func() {
|
|
||||||
It("satisfies syncstate.Store for *distributed.FineTuneStore", func() {
|
|
||||||
// Guards against drift between the adapter and the component interface;
|
|
||||||
// the var assertion in syncstore.go covers it at build time, this keeps
|
|
||||||
// the type referenced from a spec too.
|
|
||||||
var _ *distributed.FineTuneStore
|
|
||||||
Expect(&fineTuneStoreAdapter{}).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
package finetune
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
)
|
|
||||||
|
|
||||||
// fineTuneStoreAdapter bridges the distributed PostgreSQL FineTuneStore to the
|
|
||||||
// generic syncstate.Store the SyncedMap consumes. It is only wired in distributed
|
|
||||||
// mode; standalone leaves Store nil and hydrates from disk via a Loader instead.
|
|
||||||
//
|
|
||||||
// The SyncedMap value type is *schema.FineTuneJob (the exact shape the REST API
|
|
||||||
// returns) so reads need no conversion and the response JSON is provably
|
|
||||||
// unchanged. The adapter is the single place that translates between that API
|
|
||||||
// shape and the DB FineTuneJobRecord.
|
|
||||||
type fineTuneStoreAdapter struct {
|
|
||||||
store *distributed.FineTuneStore
|
|
||||||
}
|
|
||||||
|
|
||||||
// compile-time assertion that the adapter satisfies the component's Store.
|
|
||||||
var _ syncstate.Store[string, *schema.FineTuneJob] = (*fineTuneStoreAdapter)(nil)
|
|
||||||
|
|
||||||
func (a *fineTuneStoreAdapter) List(_ context.Context) ([]*schema.FineTuneJob, error) {
|
|
||||||
records, err := a.store.ListAll()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
jobs := make([]*schema.FineTuneJob, 0, len(records))
|
|
||||||
for i := range records {
|
|
||||||
jobs = append(jobs, recordToJob(&records[i]))
|
|
||||||
}
|
|
||||||
return jobs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *fineTuneStoreAdapter) Upsert(_ context.Context, job *schema.FineTuneJob) error {
|
|
||||||
return a.store.Upsert(jobToRecord(job))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *fineTuneStoreAdapter) Delete(_ context.Context, id string) error {
|
|
||||||
return a.store.Delete(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// recordToJob maps a persisted DB record back to the API shape, reconstructing
|
|
||||||
// the structured Config / ExtraOptions from their JSON columns.
|
|
||||||
func recordToJob(r *distributed.FineTuneJobRecord) *schema.FineTuneJob {
|
|
||||||
job := &schema.FineTuneJob{
|
|
||||||
ID: r.ID,
|
|
||||||
UserID: r.UserID,
|
|
||||||
Model: r.Model,
|
|
||||||
Backend: r.Backend,
|
|
||||||
ModelID: r.ModelID,
|
|
||||||
TrainingType: r.TrainingType,
|
|
||||||
TrainingMethod: r.TrainingMethod,
|
|
||||||
Status: r.Status,
|
|
||||||
Message: r.Message,
|
|
||||||
OutputDir: r.OutputDir,
|
|
||||||
ExportStatus: r.ExportStatus,
|
|
||||||
ExportMessage: r.ExportMessage,
|
|
||||||
ExportModelName: r.ExportModelName,
|
|
||||||
CreatedAt: r.CreatedAt.UTC().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
if r.ExtraOptsJSON != "" {
|
|
||||||
// Best-effort: a malformed column must not drop the whole job from the API.
|
|
||||||
_ = json.Unmarshal([]byte(r.ExtraOptsJSON), &job.ExtraOptions)
|
|
||||||
}
|
|
||||||
if r.ConfigJSON != "" {
|
|
||||||
var cfg schema.FineTuneJobRequest
|
|
||||||
if err := json.Unmarshal([]byte(r.ConfigJSON), &cfg); err == nil {
|
|
||||||
job.Config = &cfg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return job
|
|
||||||
}
|
|
||||||
|
|
||||||
// jobToRecord maps the API shape to a DB record for write-through, serializing
|
|
||||||
// the structured Config / ExtraOptions into their JSON columns. CreatedAt is
|
|
||||||
// parsed back from the RFC3339 string the service stamps; an unparseable value
|
|
||||||
// is left zero so FineTuneStore.Upsert stamps "now".
|
|
||||||
func jobToRecord(job *schema.FineTuneJob) *distributed.FineTuneJobRecord {
|
|
||||||
rec := &distributed.FineTuneJobRecord{
|
|
||||||
ID: job.ID,
|
|
||||||
UserID: job.UserID,
|
|
||||||
Model: job.Model,
|
|
||||||
Backend: job.Backend,
|
|
||||||
ModelID: job.ModelID,
|
|
||||||
TrainingType: job.TrainingType,
|
|
||||||
TrainingMethod: job.TrainingMethod,
|
|
||||||
Status: job.Status,
|
|
||||||
Message: job.Message,
|
|
||||||
OutputDir: job.OutputDir,
|
|
||||||
ExportStatus: job.ExportStatus,
|
|
||||||
ExportMessage: job.ExportMessage,
|
|
||||||
ExportModelName: job.ExportModelName,
|
|
||||||
}
|
|
||||||
if job.Config != nil {
|
|
||||||
if data, err := json.Marshal(job.Config); err == nil {
|
|
||||||
rec.ConfigJSON = string(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if job.ExtraOptions != nil {
|
|
||||||
if data, err := json.Marshal(job.ExtraOptions); err == nil {
|
|
||||||
rec.ExtraOptsJSON = string(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if t, err := time.Parse(time.RFC3339, job.CreatedAt); err == nil {
|
|
||||||
rec.CreatedAt = t
|
|
||||||
}
|
|
||||||
return rec
|
|
||||||
}
|
|
||||||
@@ -22,14 +22,6 @@ const subscribeConfirmTimeout = 5 * time.Second
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
conn *nats.Conn
|
conn *nats.Conn
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
// reconnectCbs are invoked after the underlying connection is
|
|
||||||
// re-established. nats.go transparently resubscribes existing
|
|
||||||
// subscriptions on reconnect, but it cannot know that a consumer kept
|
|
||||||
// derived in-memory state (e.g. syncstate.SyncedMap) that may have drifted
|
|
||||||
// while the link was down — these callbacks let such consumers re-hydrate.
|
|
||||||
cbMu sync.Mutex
|
|
||||||
reconnectCbs []func()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new NATS client with auto-reconnect.
|
// New creates a new NATS client with auto-reconnect.
|
||||||
@@ -39,10 +31,6 @@ func New(url string, opts ...Option) (*Client, error) {
|
|||||||
o(&cfg)
|
o(&cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate the client up front so the reconnect handler closure can reach
|
|
||||||
// it; conn is populated after nats.Connect succeeds below.
|
|
||||||
c := &Client{}
|
|
||||||
|
|
||||||
natsOpts := []nats.Option{
|
natsOpts := []nats.Option{
|
||||||
nats.RetryOnFailedConnect(true),
|
nats.RetryOnFailedConnect(true),
|
||||||
nats.MaxReconnects(-1),
|
nats.MaxReconnects(-1),
|
||||||
@@ -53,7 +41,6 @@ func New(url string, opts ...Option) (*Client, error) {
|
|||||||
}),
|
}),
|
||||||
nats.ReconnectHandler(func(_ *nats.Conn) {
|
nats.ReconnectHandler(func(_ *nats.Conn) {
|
||||||
xlog.Info("NATS reconnected")
|
xlog.Info("NATS reconnected")
|
||||||
c.runReconnectCallbacks()
|
|
||||||
}),
|
}),
|
||||||
nats.ClosedHandler(func(_ *nats.Conn) {
|
nats.ClosedHandler(func(_ *nats.Conn) {
|
||||||
xlog.Info("NATS connection closed")
|
xlog.Info("NATS connection closed")
|
||||||
@@ -116,33 +103,7 @@ func New(url string, opts ...Option) (*Client, error) {
|
|||||||
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
|
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn = nc
|
return &Client{conn: nc}, nil
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnReconnect registers a callback invoked after the NATS connection is
|
|
||||||
// re-established. It is consumed via an optional interface type-assertion
|
|
||||||
// (interface{ OnReconnect(func()) }) rather than being added to MessagingClient,
|
|
||||||
// so the messaging abstraction stays minimal and standalone/test clients are not
|
|
||||||
// forced to implement reconnect semantics. A nil callback is ignored.
|
|
||||||
func (c *Client) OnReconnect(cb func()) {
|
|
||||||
if cb == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.cbMu.Lock()
|
|
||||||
c.reconnectCbs = append(c.reconnectCbs, cb)
|
|
||||||
c.cbMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// runReconnectCallbacks invokes registered reconnect callbacks. It copies the
|
|
||||||
// slice under the lock so a callback that (re)registers cannot deadlock.
|
|
||||||
func (c *Client) runReconnectCallbacks() {
|
|
||||||
c.cbMu.Lock()
|
|
||||||
cbs := append([]func(){}, c.reconnectCbs...)
|
|
||||||
c.cbMu.Unlock()
|
|
||||||
for _, cb := range cbs {
|
|
||||||
cb()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish marshals data as JSON and publishes it to the given subject.
|
// Publish marshals data as JSON and publishes it to the given subject.
|
||||||
|
|||||||
@@ -380,20 +380,6 @@ func SubjectCacheInvalidateCollection(name string) string {
|
|||||||
return "cache.invalidate.collections." + sanitizeSubjectToken(name)
|
return "cache.invalidate.collections." + sanitizeSubjectToken(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncedMap State Sync (Pub/Sub — broadcast to all frontends)
|
|
||||||
//
|
|
||||||
// The reusable syncstate.SyncedMap component publishes a {op,key,value} delta on
|
|
||||||
// this subject whenever a replica mutates a piece of cross-replica in-memory
|
|
||||||
// state. Peers subscribe and apply the delta to their own map, so a round-robin
|
|
||||||
// API request that lands on a replica which did not originate the change still
|
|
||||||
// sees it. Convergence on (re)connect is done by re-hydrating from the durable
|
|
||||||
// source, so no request/reply snapshot subject is needed here.
|
|
||||||
func SubjectSyncStateDelta(name string) string {
|
|
||||||
return subjectSyncStatePrefix + sanitizeSubjectToken(name) + ".delta"
|
|
||||||
}
|
|
||||||
|
|
||||||
const subjectSyncStatePrefix = "state."
|
|
||||||
|
|
||||||
// Prefix-Cache Routing Sync (Pub/Sub - broadcast to all frontends)
|
// Prefix-Cache Routing Sync (Pub/Sub - broadcast to all frontends)
|
||||||
//
|
//
|
||||||
// Frontends share prefix-cache observations so a request routed to any replica
|
// Frontends share prefix-cache observations so a request routed to any replica
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package quantization
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestQuantization(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Quantization Suite")
|
|
||||||
}
|
|
||||||
@@ -17,9 +17,6 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
@@ -33,63 +30,26 @@ type QuantizationService struct {
|
|||||||
modelLoader *model.ModelLoader
|
modelLoader *model.ModelLoader
|
||||||
configLoader *config.ModelConfigLoader
|
configLoader *config.ModelConfigLoader
|
||||||
|
|
||||||
// mu serializes the read-modify-write of job values. The SyncedMap guards its
|
mu sync.Mutex
|
||||||
// own map structure, but a job is a pointer mutated in place (e.g. the import
|
jobs map[string]*schema.QuantizationJob
|
||||||
// goroutine), so the service still needs a lock to keep those field updates and
|
|
||||||
// the subsequent Set atomic with respect to readers.
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
// jobs is the cross-replica job store: an in-memory map kept consistent across
|
|
||||||
// replicas via NATS, optionally read-through to PostgreSQL in distributed mode.
|
|
||||||
jobs *syncstate.SyncedMap[string, *schema.QuantizationJob]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQuantizationService creates a new QuantizationService. In distributed mode
|
// NewQuantizationService creates a new QuantizationService.
|
||||||
// pass the shared NATS client and PostgreSQL store so jobs stay consistent across
|
|
||||||
// replicas; pass nil for both in standalone mode, where the disk Loader hydrates
|
|
||||||
// the map and there is nothing to broadcast.
|
|
||||||
func NewQuantizationService(
|
func NewQuantizationService(
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
modelLoader *model.ModelLoader,
|
modelLoader *model.ModelLoader,
|
||||||
configLoader *config.ModelConfigLoader,
|
configLoader *config.ModelConfigLoader,
|
||||||
nats messaging.MessagingClient,
|
|
||||||
store *distributed.QuantStore,
|
|
||||||
) *QuantizationService {
|
) *QuantizationService {
|
||||||
s := &QuantizationService{
|
s := &QuantizationService{
|
||||||
appConfig: appConfig,
|
appConfig: appConfig,
|
||||||
modelLoader: modelLoader,
|
modelLoader: modelLoader,
|
||||||
configLoader: configLoader,
|
configLoader: configLoader,
|
||||||
|
jobs: make(map[string]*schema.QuantizationJob),
|
||||||
}
|
}
|
||||||
|
s.loadAllJobs()
|
||||||
// Only attach a Store interface when a concrete store exists, otherwise the
|
|
||||||
// SyncedMap would see a non-nil interface wrapping a nil pointer and try to
|
|
||||||
// hydrate/write through a nil DB.
|
|
||||||
var syncStore syncstate.Store[string, *schema.QuantizationJob]
|
|
||||||
if store != nil {
|
|
||||||
syncStore = &quantStoreAdapter{store: store}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.jobs = syncstate.New(syncstate.Config[string, *schema.QuantizationJob]{
|
|
||||||
Name: "quant.jobs",
|
|
||||||
Key: func(j *schema.QuantizationJob) string { return j.ID },
|
|
||||||
Nats: nats,
|
|
||||||
Store: syncStore,
|
|
||||||
Loader: s.loadJobsFromDisk, // ignored when Store is set (distributed mode)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Hydrate + subscribe. A hydrate failure must not take the server down: log and
|
|
||||||
// continue degraded (standalone), mirroring the FineTune/OpCache wiring.
|
|
||||||
if err := s.jobs.Start(appConfig.Context); err != nil {
|
|
||||||
xlog.Warn("Quantization SyncedMap start failed; running degraded", "error", err)
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close releases the SyncedMap subscription and background workers.
|
|
||||||
func (s *QuantizationService) Close() error {
|
|
||||||
return s.jobs.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// quantizationBaseDir returns the base directory for quantization job data.
|
// quantizationBaseDir returns the base directory for quantization job data.
|
||||||
func (s *QuantizationService) quantizationBaseDir() string {
|
func (s *QuantizationService) quantizationBaseDir() string {
|
||||||
return filepath.Join(s.appConfig.DataPath, "quantization")
|
return filepath.Join(s.appConfig.DataPath, "quantization")
|
||||||
@@ -120,18 +80,15 @@ func (s *QuantizationService) saveJobState(job *schema.QuantizationJob) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadJobsFromDisk scans the quantization directory for persisted jobs and
|
// loadAllJobs scans the quantization directory for persisted jobs and loads them.
|
||||||
// returns them. It is the SyncedMap Loader used in standalone mode (no DB); the
|
func (s *QuantizationService) loadAllJobs() {
|
||||||
// returned slice hydrates the map on Start.
|
|
||||||
func (s *QuantizationService) loadJobsFromDisk(_ context.Context) ([]*schema.QuantizationJob, error) {
|
|
||||||
baseDir := s.quantizationBaseDir()
|
baseDir := s.quantizationBaseDir()
|
||||||
entries, err := os.ReadDir(baseDir)
|
entries, err := os.ReadDir(baseDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Directory doesn't exist yet — that's fine, start empty.
|
// Directory doesn't exist yet — that's fine
|
||||||
return nil, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs []*schema.QuantizationJob
|
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if !entry.IsDir() {
|
if !entry.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -160,13 +117,12 @@ func (s *QuantizationService) loadJobsFromDisk(_ context.Context) ([]*schema.Qua
|
|||||||
job.ImportMessage = "Server restarted while import was running"
|
job.ImportMessage = "Server restarted while import was running"
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs = append(jobs, &job)
|
s.jobs[job.ID] = &job
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(jobs) > 0 {
|
if len(s.jobs) > 0 {
|
||||||
xlog.Info("Loaded persisted quantization jobs", "count", len(jobs))
|
xlog.Info("Loaded persisted quantization jobs", "count", len(s.jobs))
|
||||||
}
|
}
|
||||||
return jobs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartJob starts a new quantization job.
|
// StartJob starts a new quantization job.
|
||||||
@@ -232,12 +188,7 @@ func (s *QuantizationService) StartJob(ctx context.Context, userID string, req s
|
|||||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
Config: &req,
|
Config: &req,
|
||||||
}
|
}
|
||||||
// Set write-through persists to PostgreSQL (distributed) and broadcasts to
|
s.jobs[jobID] = job
|
||||||
// peer replicas; the disk state.json is written separately for restart
|
|
||||||
// recovery / standalone hydrate.
|
|
||||||
if err := s.jobs.Set(ctx, job); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to persist job: %w", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
|
|
||||||
return &schema.QuantizationJobResponse{
|
return &schema.QuantizationJobResponse{
|
||||||
@@ -252,7 +203,7 @@ func (s *QuantizationService) GetJob(userID, jobID string) (*schema.Quantization
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
}
|
}
|
||||||
@@ -268,7 +219,7 @@ func (s *QuantizationService) ListJobs(userID string) []*schema.QuantizationJob
|
|||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var result []*schema.QuantizationJob
|
var result []*schema.QuantizationJob
|
||||||
for _, job := range s.jobs.List() {
|
for _, job := range s.jobs {
|
||||||
if userID == "" || job.UserID == userID {
|
if userID == "" || job.UserID == userID {
|
||||||
result = append(result, job)
|
result = append(result, job)
|
||||||
}
|
}
|
||||||
@@ -284,7 +235,7 @@ func (s *QuantizationService) ListJobs(userID string) []*schema.QuantizationJob
|
|||||||
// StopJob stops a running quantization job.
|
// StopJob stops a running quantization job.
|
||||||
func (s *QuantizationService) StopJob(ctx context.Context, userID, jobID string) error {
|
func (s *QuantizationService) StopJob(ctx context.Context, userID, jobID string) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("job not found: %s", jobID)
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -305,9 +256,6 @@ func (s *QuantizationService) StopJob(ctx context.Context, userID, jobID string)
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.Status = "stopped"
|
job.Status = "stopped"
|
||||||
job.Message = "Quantization stopped by user"
|
job.Message = "Quantization stopped by user"
|
||||||
if err := s.jobs.Set(ctx, job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist stopped job", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
@@ -317,7 +265,7 @@ func (s *QuantizationService) StopJob(ctx context.Context, userID, jobID string)
|
|||||||
// DeleteJob removes a quantization job and its associated data from disk.
|
// DeleteJob removes a quantization job and its associated data from disk.
|
||||||
func (s *QuantizationService) DeleteJob(userID, jobID string) error {
|
func (s *QuantizationService) DeleteJob(userID, jobID string) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("job not found: %s", jobID)
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -341,11 +289,7 @@ func (s *QuantizationService) DeleteJob(userID, jobID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
importModelName := job.ImportModelName
|
importModelName := job.ImportModelName
|
||||||
// Delete write-through removes the DB row (distributed) and broadcasts the
|
delete(s.jobs, jobID)
|
||||||
// removal to peer replicas. DeleteJob has no ctx, so use Background.
|
|
||||||
if err := s.jobs.Delete(context.Background(), jobID); err != nil {
|
|
||||||
xlog.Warn("Failed to delete job from store", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Remove job directory (state.json, output files)
|
// Remove job directory (state.json, output files)
|
||||||
@@ -380,7 +324,7 @@ func (s *QuantizationService) DeleteJob(userID, jobID string) error {
|
|||||||
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
|
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
|
||||||
func (s *QuantizationService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.QuantizationProgressEvent)) error {
|
func (s *QuantizationService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.QuantizationProgressEvent)) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("job not found: %s", jobID)
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -409,7 +353,7 @@ func (s *QuantizationService) StreamProgress(ctx context.Context, userID, jobID
|
|||||||
}, func(update *pb.QuantizationProgressUpdate) {
|
}, func(update *pb.QuantizationProgressUpdate) {
|
||||||
// Update job status and persist
|
// Update job status and persist
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if j, ok := s.jobs.Get(jobID); ok {
|
if j, ok := s.jobs[jobID]; ok {
|
||||||
// Don't let progress updates overwrite terminal states
|
// Don't let progress updates overwrite terminal states
|
||||||
isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
|
isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
|
||||||
if !isTerminal {
|
if !isTerminal {
|
||||||
@@ -421,9 +365,6 @@ func (s *QuantizationService) StreamProgress(ctx context.Context, userID, jobID
|
|||||||
if update.OutputFile != "" {
|
if update.OutputFile != "" {
|
||||||
j.OutputFile = update.OutputFile
|
j.OutputFile = update.OutputFile
|
||||||
}
|
}
|
||||||
if err := s.jobs.Set(ctx, j); err != nil {
|
|
||||||
xlog.Warn("Failed to persist progress update", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(j)
|
s.saveJobState(j)
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
@@ -458,7 +399,7 @@ func sanitizeQuantModelName(s string) string {
|
|||||||
// ImportModel imports a quantized model into LocalAI asynchronously.
|
// ImportModel imports a quantized model into LocalAI asynchronously.
|
||||||
func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID string, req schema.QuantizationImportRequest) (string, error) {
|
func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID string, req schema.QuantizationImportRequest) (string, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return "", fmt.Errorf("job not found: %s", jobID)
|
return "", fmt.Errorf("job not found: %s", jobID)
|
||||||
@@ -518,9 +459,6 @@ func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID str
|
|||||||
job.ImportStatus = "importing"
|
job.ImportStatus = "importing"
|
||||||
job.ImportMessage = ""
|
job.ImportMessage = ""
|
||||||
job.ImportModelName = ""
|
job.ImportModelName = ""
|
||||||
if err := s.jobs.Set(ctx, job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist import start", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
@@ -576,15 +514,10 @@ func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID str
|
|||||||
|
|
||||||
xlog.Info("Quantized model imported and registered", "job_id", jobID, "model_name", modelName)
|
xlog.Info("Quantized model imported and registered", "job_id", jobID, "model_name", modelName)
|
||||||
|
|
||||||
// Runs after the HTTP request returns, so use Background rather than the
|
|
||||||
// (now likely cancelled) request ctx for the write-through.
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.ImportStatus = "completed"
|
job.ImportStatus = "completed"
|
||||||
job.ImportModelName = modelName
|
job.ImportModelName = modelName
|
||||||
job.ImportMessage = ""
|
job.ImportMessage = ""
|
||||||
if err := s.jobs.Set(context.Background(), job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist import completion", "job_id", jobID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
@@ -592,14 +525,10 @@ func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID str
|
|||||||
return modelName, nil
|
return modelName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setImportMessage updates the import message and persists the job state. Called
|
// setImportMessage updates the import message and persists the job state.
|
||||||
// from the background import goroutine, so it uses Background for write-through.
|
|
||||||
func (s *QuantizationService) setImportMessage(job *schema.QuantizationJob, msg string) {
|
func (s *QuantizationService) setImportMessage(job *schema.QuantizationJob, msg string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.ImportMessage = msg
|
job.ImportMessage = msg
|
||||||
if err := s.jobs.Set(context.Background(), job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist import message", "job_id", job.ID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -610,9 +539,6 @@ func (s *QuantizationService) setImportFailed(job *schema.QuantizationJob, messa
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job.ImportStatus = "failed"
|
job.ImportStatus = "failed"
|
||||||
job.ImportMessage = message
|
job.ImportMessage = message
|
||||||
if err := s.jobs.Set(context.Background(), job); err != nil {
|
|
||||||
xlog.Warn("Failed to persist import failure", "job_id", job.ID, "error", err)
|
|
||||||
}
|
|
||||||
s.saveJobState(job)
|
s.saveJobState(job)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -620,7 +546,7 @@ func (s *QuantizationService) setImportFailed(job *schema.QuantizationJob, messa
|
|||||||
// GetOutputPath returns the path to the quantized model file and a download name.
|
// GetOutputPath returns the path to the quantized model file and a download name.
|
||||||
func (s *QuantizationService) GetOutputPath(userID, jobID string) (string, string, error) {
|
func (s *QuantizationService) GetOutputPath(userID, jobID string) (string, string, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
job, ok := s.jobs.Get(jobID)
|
job, ok := s.jobs[jobID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return "", "", fmt.Errorf("job not found: %s", jobID)
|
return "", "", fmt.Errorf("job not found: %s", jobID)
|
||||||
|
|||||||
@@ -1,187 +0,0 @@
|
|||||||
package quantization
|
|
||||||
|
|
||||||
// White-box tests (package quantization) so a spec can drive the service's
|
|
||||||
// internal SyncedMap the same way StartJob does (via jobs.Set) without standing
|
|
||||||
// up a quantization backend, then assert the cross-replica reads
|
|
||||||
// (GetJob/ListJobs) and the adapter conversions that keep REST responses
|
|
||||||
// byte-for-byte unchanged.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newTestService builds a standalone QuantizationService wired to the given bus.
|
|
||||||
// The model/config loaders are nil because the read/sync paths under test never
|
|
||||||
// touch them; the data dir is a throwaway temp dir so the disk Loader finds
|
|
||||||
// nothing.
|
|
||||||
func newTestService(bus *testutil.FakeBus) *QuantizationService {
|
|
||||||
appConfig := &config.ApplicationConfig{
|
|
||||||
Context: context.Background(),
|
|
||||||
DataPath: GinkgoT().TempDir(),
|
|
||||||
}
|
|
||||||
return NewQuantizationService(appConfig, nil, nil, bus, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("QuantizationService", func() {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
Describe("cross-replica job visibility", func() {
|
|
||||||
var (
|
|
||||||
bus *testutil.FakeBus
|
|
||||||
a, b *QuantizationService
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
// One shared bus, two replicas: exactly the distributed topology where a
|
|
||||||
// round-robin request may land on a replica that did not originate the
|
|
||||||
// change.
|
|
||||||
bus = testutil.NewFakeBus()
|
|
||||||
a = newTestService(bus)
|
|
||||||
b = newTestService(bus)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(a.Close()).To(Succeed())
|
|
||||||
Expect(b.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("makes a job created on A visible via B's GetJob and ListJobs", func() {
|
|
||||||
job := &schema.QuantizationJob{ID: "job-1", UserID: "user-1", Status: "queued", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
// StartJob persists via jobs.Set; drive that directly to avoid a backend.
|
|
||||||
Expect(a.jobs.Set(ctx, job)).To(Succeed())
|
|
||||||
|
|
||||||
got, err := b.GetJob("user-1", "job-1")
|
|
||||||
Expect(err).ToNot(HaveOccurred(), "B must see a job A just created")
|
|
||||||
Expect(got.Status).To(Equal("queued"))
|
|
||||||
|
|
||||||
listed := b.ListJobs("user-1")
|
|
||||||
Expect(listed).To(HaveLen(1))
|
|
||||||
Expect(listed[0].ID).To(Equal("job-1"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes a job from B when it is deleted on A", func() {
|
|
||||||
job := &schema.QuantizationJob{ID: "job-2", UserID: "user-1", Status: "completed", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
Expect(a.jobs.Set(ctx, job)).To(Succeed())
|
|
||||||
_, err := b.GetJob("user-1", "job-2")
|
|
||||||
Expect(err).ToNot(HaveOccurred(), "precondition: B must have the job before the delete")
|
|
||||||
|
|
||||||
Expect(a.jobs.Delete(ctx, "job-2")).To(Succeed())
|
|
||||||
|
|
||||||
_, err = b.GetJob("user-1", "job-2")
|
|
||||||
Expect(err).To(HaveOccurred(), "a delete on A must remove the job from B")
|
|
||||||
})
|
|
||||||
|
|
||||||
It("propagates a status update from A to B", func() {
|
|
||||||
job := &schema.QuantizationJob{ID: "job-3", UserID: "user-1", Status: "quantizing", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
Expect(a.jobs.Set(ctx, job)).To(Succeed())
|
|
||||||
|
|
||||||
updated := &schema.QuantizationJob{ID: "job-3", UserID: "user-1", Status: "completed", CreatedAt: "2026-06-27T10:00:00Z"}
|
|
||||||
Expect(a.jobs.Set(ctx, updated)).To(Succeed())
|
|
||||||
|
|
||||||
got, err := b.GetJob("user-1", "job-3")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(got.Status).To(Equal("completed"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("ListJobs", func() {
|
|
||||||
var svc *QuantizationService
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
svc = newTestService(testutil.NewFakeBus())
|
|
||||||
})
|
|
||||||
AfterEach(func() { Expect(svc.Close()).To(Succeed()) })
|
|
||||||
|
|
||||||
It("filters by user and sorts newest-first", func() {
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.QuantizationJob{ID: "old", UserID: "u1", CreatedAt: "2026-06-25T10:00:00Z"})).To(Succeed())
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.QuantizationJob{ID: "new", UserID: "u1", CreatedAt: "2026-06-27T10:00:00Z"})).To(Succeed())
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.QuantizationJob{ID: "other", UserID: "u2", CreatedAt: "2026-06-26T10:00:00Z"})).To(Succeed())
|
|
||||||
|
|
||||||
jobs := svc.ListJobs("u1")
|
|
||||||
Expect(jobs).To(HaveLen(2), "only u1's jobs")
|
|
||||||
Expect(jobs[0].ID).To(Equal("new"), "newest first")
|
|
||||||
Expect(jobs[1].ID).To(Equal("old"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns every user's jobs when the userID filter is empty", func() {
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.QuantizationJob{ID: "a", UserID: "u1", CreatedAt: "2026-06-25T10:00:00Z"})).To(Succeed())
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.QuantizationJob{ID: "b", UserID: "u2", CreatedAt: "2026-06-26T10:00:00Z"})).To(Succeed())
|
|
||||||
|
|
||||||
Expect(svc.ListJobs("")).To(HaveLen(2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects GetJob for a job owned by another user", func() {
|
|
||||||
Expect(svc.jobs.Set(ctx, &schema.QuantizationJob{ID: "x", UserID: "owner", CreatedAt: "2026-06-25T10:00:00Z"})).To(Succeed())
|
|
||||||
|
|
||||||
_, err := svc.GetJob("intruder", "x")
|
|
||||||
Expect(err).To(HaveOccurred(), "a different user must not read someone else's job")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("store adapter conversion", func() {
|
|
||||||
// The SyncedMap value type is *schema.QuantizationJob (the exact REST shape).
|
|
||||||
// These specs prove the DB adapter round-trips it losslessly, so hydrate and
|
|
||||||
// write-through in distributed mode keep responses unchanged.
|
|
||||||
It("round-trips a job through jobToRecord/recordToJob preserving the API shape", func() {
|
|
||||||
original := &schema.QuantizationJob{
|
|
||||||
ID: "rt-1",
|
|
||||||
UserID: "user-1",
|
|
||||||
Model: "base-model",
|
|
||||||
Backend: "llama-cpp-quantization",
|
|
||||||
ModelID: "llama-cpp-quantization-quantize-rt-1",
|
|
||||||
QuantizationType: "q4_k_m",
|
|
||||||
Status: "completed",
|
|
||||||
Message: "done",
|
|
||||||
OutputDir: "/data/quantization/rt-1",
|
|
||||||
OutputFile: "/data/quantization/rt-1/model.gguf",
|
|
||||||
ExtraOptions: map[string]string{"hf_token": "secret"},
|
|
||||||
CreatedAt: "2026-06-27T10:00:00Z",
|
|
||||||
ImportStatus: "completed",
|
|
||||||
ImportMessage: "",
|
|
||||||
ImportModelName: "base-model-q4_k_m-rt-1",
|
|
||||||
Config: &schema.QuantizationJobRequest{Model: "base-model", Backend: "llama-cpp-quantization", QuantizationType: "q4_k_m"},
|
|
||||||
}
|
|
||||||
|
|
||||||
rec := jobToRecord(original)
|
|
||||||
Expect(rec.ID).To(Equal("rt-1"))
|
|
||||||
Expect(rec.ConfigJSON).ToNot(BeEmpty(), "structured config must serialize into the JSON column")
|
|
||||||
Expect(rec.ExtraOptsJSON).ToNot(BeEmpty())
|
|
||||||
|
|
||||||
back := recordToJob(rec)
|
|
||||||
Expect(back.ID).To(Equal(original.ID))
|
|
||||||
Expect(back.UserID).To(Equal(original.UserID))
|
|
||||||
Expect(back.Model).To(Equal(original.Model))
|
|
||||||
Expect(back.Backend).To(Equal(original.Backend))
|
|
||||||
Expect(back.ModelID).To(Equal(original.ModelID))
|
|
||||||
Expect(back.QuantizationType).To(Equal(original.QuantizationType))
|
|
||||||
Expect(back.Status).To(Equal(original.Status))
|
|
||||||
Expect(back.Message).To(Equal(original.Message))
|
|
||||||
Expect(back.OutputDir).To(Equal(original.OutputDir))
|
|
||||||
Expect(back.OutputFile).To(Equal(original.OutputFile))
|
|
||||||
Expect(back.ImportStatus).To(Equal(original.ImportStatus))
|
|
||||||
Expect(back.ImportModelName).To(Equal(original.ImportModelName))
|
|
||||||
Expect(back.CreatedAt).To(Equal(original.CreatedAt))
|
|
||||||
Expect(back.ExtraOptions).To(Equal(original.ExtraOptions))
|
|
||||||
Expect(back.Config).ToNot(BeNil())
|
|
||||||
Expect(back.Config.QuantizationType).To(Equal("q4_k_m"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("compile-time adapter contract", func() {
|
|
||||||
It("satisfies syncstate.Store for *distributed.QuantStore", func() {
|
|
||||||
// Guards against drift between the adapter and the component interface;
|
|
||||||
// the var assertion in syncstore.go covers it at build time, this keeps
|
|
||||||
// the type referenced from a spec too.
|
|
||||||
var _ *distributed.QuantStore
|
|
||||||
Expect(&quantStoreAdapter{}).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
package quantization
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/core/services/distributed"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
)
|
|
||||||
|
|
||||||
// quantStoreAdapter bridges the distributed PostgreSQL QuantStore to the generic
|
|
||||||
// syncstate.Store the SyncedMap consumes. It is only wired in distributed mode;
|
|
||||||
// standalone leaves Store nil and hydrates from disk via a Loader instead.
|
|
||||||
//
|
|
||||||
// The SyncedMap value type is *schema.QuantizationJob (the exact shape the REST
|
|
||||||
// API returns) so reads need no conversion and the response JSON is provably
|
|
||||||
// unchanged. The adapter is the single place that translates between that API
|
|
||||||
// shape and the DB QuantJobRecord.
|
|
||||||
type quantStoreAdapter struct {
|
|
||||||
store *distributed.QuantStore
|
|
||||||
}
|
|
||||||
|
|
||||||
// compile-time assertion that the adapter satisfies the component's Store.
|
|
||||||
var _ syncstate.Store[string, *schema.QuantizationJob] = (*quantStoreAdapter)(nil)
|
|
||||||
|
|
||||||
func (a *quantStoreAdapter) List(_ context.Context) ([]*schema.QuantizationJob, error) {
|
|
||||||
records, err := a.store.ListAll()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
jobs := make([]*schema.QuantizationJob, 0, len(records))
|
|
||||||
for i := range records {
|
|
||||||
jobs = append(jobs, recordToJob(&records[i]))
|
|
||||||
}
|
|
||||||
return jobs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *quantStoreAdapter) Upsert(_ context.Context, job *schema.QuantizationJob) error {
|
|
||||||
return a.store.Upsert(jobToRecord(job))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *quantStoreAdapter) Delete(_ context.Context, id string) error {
|
|
||||||
return a.store.Delete(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// recordToJob maps a persisted DB record back to the API shape, reconstructing
|
|
||||||
// the structured Config / ExtraOptions from their JSON columns.
|
|
||||||
func recordToJob(r *distributed.QuantJobRecord) *schema.QuantizationJob {
|
|
||||||
job := &schema.QuantizationJob{
|
|
||||||
ID: r.ID,
|
|
||||||
UserID: r.UserID,
|
|
||||||
Model: r.Model,
|
|
||||||
Backend: r.Backend,
|
|
||||||
ModelID: r.ModelID,
|
|
||||||
QuantizationType: r.QuantizationType,
|
|
||||||
Status: r.Status,
|
|
||||||
Message: r.Message,
|
|
||||||
OutputDir: r.OutputDir,
|
|
||||||
OutputFile: r.OutputFile,
|
|
||||||
ImportStatus: r.ImportStatus,
|
|
||||||
ImportMessage: r.ImportMessage,
|
|
||||||
ImportModelName: r.ImportModelName,
|
|
||||||
CreatedAt: r.CreatedAt.UTC().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
if r.ExtraOptsJSON != "" {
|
|
||||||
// Best-effort: a malformed column must not drop the whole job from the API.
|
|
||||||
_ = json.Unmarshal([]byte(r.ExtraOptsJSON), &job.ExtraOptions)
|
|
||||||
}
|
|
||||||
if r.ConfigJSON != "" {
|
|
||||||
var cfg schema.QuantizationJobRequest
|
|
||||||
if err := json.Unmarshal([]byte(r.ConfigJSON), &cfg); err == nil {
|
|
||||||
job.Config = &cfg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return job
|
|
||||||
}
|
|
||||||
|
|
||||||
// jobToRecord maps the API shape to a DB record for write-through, serializing
|
|
||||||
// the structured Config / ExtraOptions into their JSON columns. CreatedAt is
|
|
||||||
// parsed back from the RFC3339 string the service stamps; an unparseable value is
|
|
||||||
// left zero so QuantStore.Upsert stamps "now".
|
|
||||||
func jobToRecord(job *schema.QuantizationJob) *distributed.QuantJobRecord {
|
|
||||||
rec := &distributed.QuantJobRecord{
|
|
||||||
ID: job.ID,
|
|
||||||
UserID: job.UserID,
|
|
||||||
Model: job.Model,
|
|
||||||
Backend: job.Backend,
|
|
||||||
ModelID: job.ModelID,
|
|
||||||
QuantizationType: job.QuantizationType,
|
|
||||||
Status: job.Status,
|
|
||||||
Message: job.Message,
|
|
||||||
OutputDir: job.OutputDir,
|
|
||||||
OutputFile: job.OutputFile,
|
|
||||||
ImportStatus: job.ImportStatus,
|
|
||||||
ImportMessage: job.ImportMessage,
|
|
||||||
ImportModelName: job.ImportModelName,
|
|
||||||
}
|
|
||||||
if job.Config != nil {
|
|
||||||
if data, err := json.Marshal(job.Config); err == nil {
|
|
||||||
rec.ConfigJSON = string(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if job.ExtraOptions != nil {
|
|
||||||
if data, err := json.Marshal(job.ExtraOptions); err == nil {
|
|
||||||
rec.ExtraOptsJSON = string(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if t, err := time.Parse(time.RFC3339, job.CreatedAt); err == nil {
|
|
||||||
rec.CreatedAt = t
|
|
||||||
}
|
|
||||||
return rec
|
|
||||||
}
|
|
||||||
@@ -1,289 +0,0 @@
|
|||||||
// Package syncstate provides SyncedMap, a reusable cross-replica in-memory map.
|
|
||||||
//
|
|
||||||
// LocalAI in distributed mode runs multiple frontend replicas behind a
|
|
||||||
// round-robin load balancer. Several features keep process-local in-memory state
|
|
||||||
// that is surfaced to the HTTP/UI API; without cross-replica sync a poll that
|
|
||||||
// lands on a replica which did not originate a change sees stale or missing data.
|
|
||||||
// SyncedMap collapses the three legs each feature otherwise hand-wires - an
|
|
||||||
// in-memory map, a NATS broadcast/apply path, and optional durable read-through -
|
|
||||||
// into one well-tested component so cross-replica consistency is a configuration
|
|
||||||
// choice rather than a bespoke re-implementation.
|
|
||||||
package syncstate
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/xlog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Op values carried on the wire and passed to OnApply.
|
|
||||||
const (
|
|
||||||
opSet = "set"
|
|
||||||
opDelete = "delete"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Store is optional durable backing for a SyncedMap. In distributed mode it is a
|
|
||||||
// single shared DB, so the apply path (a delta received from a peer) updates
|
|
||||||
// memory only and never re-writes the Store.
|
|
||||||
type Store[K comparable, V any] interface {
|
|
||||||
List(ctx context.Context) ([]V, error)
|
|
||||||
Upsert(ctx context.Context, v V) error
|
|
||||||
Delete(ctx context.Context, k K) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config configures a SyncedMap.
|
|
||||||
type Config[K comparable, V any] struct {
|
|
||||||
Name string // subject namespace, e.g. "finetune.jobs"
|
|
||||||
Key func(V) K // extract the key from a value
|
|
||||||
Nats messaging.MessagingClient // nil => standalone: in-memory only, no broadcast/subscribe
|
|
||||||
Store Store[K, V] // optional read-through persistence
|
|
||||||
Loader func(ctx context.Context) ([]V, error) // source when there is no Store (e.g. disk reload)
|
|
||||||
OnApply func(op string, k K, v V) // optional hook after an applied change (e.g. ShutdownModel)
|
|
||||||
Reconcile time.Duration // optional periodic re-hydrate; 0 = off
|
|
||||||
}
|
|
||||||
|
|
||||||
// delta is the JSON wire envelope broadcast on every local mutation. Value is
|
|
||||||
// omitempty so a delete carries only op+key.
|
|
||||||
type delta[K comparable, V any] struct {
|
|
||||||
Op string `json:"op"`
|
|
||||||
Key K `json:"key"`
|
|
||||||
Value V `json:"value,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncedMap is a cross-replica in-memory map. A local write (Set/Delete) updates
|
|
||||||
// memory, the optional durable Store, then broadcasts a delta to peers. A peer's
|
|
||||||
// delta updates memory only and fires OnApply - it never re-broadcasts and never
|
|
||||||
// writes the Store. That structural split is the echo-loop guard (same pattern as
|
|
||||||
// galleryop.mergeStatus / OpCache.applyStart): receiving your own broadcast just
|
|
||||||
// re-applies an idempotent value to memory, so there is no storm and no
|
|
||||||
// double-write.
|
|
||||||
type SyncedMap[K comparable, V any] struct {
|
|
||||||
cfg Config[K, V]
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
data map[K]V
|
|
||||||
|
|
||||||
sub Subscription
|
|
||||||
|
|
||||||
// lifeCtx outlives Start's argument: a reconnect callback or reconcile tick
|
|
||||||
// can fire long after Start returns, so they must not be tied to a ctx the
|
|
||||||
// caller may cancel. Close cancels it.
|
|
||||||
lifeCtx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subscription is the subset of messaging.Subscription the component holds onto.
|
|
||||||
type Subscription = messaging.Subscription
|
|
||||||
|
|
||||||
// New constructs a SyncedMap. Call Start to hydrate and begin syncing.
|
|
||||||
func New[K comparable, V any](cfg Config[K, V]) *SyncedMap[K, V] {
|
|
||||||
return &SyncedMap[K, V]{cfg: cfg, data: make(map[K]V)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SyncedMap[K, V]) subject() string {
|
|
||||||
return messaging.SubjectSyncStateDelta(m.cfg.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start hydrates from the source, subscribes for peer deltas, registers a
|
|
||||||
// reconnect re-hydrate (when the client supports it), and starts the optional
|
|
||||||
// reconcile ticker.
|
|
||||||
func (m *SyncedMap[K, V]) Start(ctx context.Context) error {
|
|
||||||
if err := m.hydrate(ctx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The cancel func is stored on the struct and invoked in Close (covered by
|
|
||||||
// tests); lifeCtx must outlive Start to drive the reconnect/reconcile
|
|
||||||
// goroutines, so it cannot be cancelled or deferred within this scope.
|
|
||||||
m.lifeCtx, m.cancel = context.WithCancel(context.Background()) // #nosec G118 -- cancel is invoked in Close()
|
|
||||||
|
|
||||||
if m.cfg.Nats != nil {
|
|
||||||
sub, err := messaging.SubscribeJSON(m.cfg.Nats, m.subject(), m.apply)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.sub = sub
|
|
||||||
|
|
||||||
// nats.go transparently resubscribes on reconnect, but it cannot know we
|
|
||||||
// kept derived in-memory state that may have drifted while the link was
|
|
||||||
// down, so re-hydrate from the durable source. Detected via an optional
|
|
||||||
// interface so MessagingClient itself stays minimal; standalone/test
|
|
||||||
// clients without the method simply fall back to the reconcile ticker.
|
|
||||||
if r, ok := m.cfg.Nats.(interface{ OnReconnect(func()) }); ok {
|
|
||||||
r.OnReconnect(func() {
|
|
||||||
if err := m.hydrate(m.lifeCtx); err != nil {
|
|
||||||
xlog.Warn("syncstate: reconnect re-hydrate failed", "name", m.cfg.Name, "error", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.cfg.Reconcile > 0 {
|
|
||||||
m.wg.Add(1)
|
|
||||||
go m.reconcileLoop()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close unsubscribes and stops the reconcile ticker.
|
|
||||||
func (m *SyncedMap[K, V]) Close() error {
|
|
||||||
if m.cancel != nil {
|
|
||||||
m.cancel()
|
|
||||||
}
|
|
||||||
m.wg.Wait()
|
|
||||||
if m.sub != nil {
|
|
||||||
return m.sub.Unsubscribe()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set updates the value locally, writes through the Store, then broadcasts.
|
|
||||||
// Per the data-flow contract the Store write happens under the lock so memory and
|
|
||||||
// durable state move together; the broadcast is best-effort after unlocking.
|
|
||||||
func (m *SyncedMap[K, V]) Set(ctx context.Context, v V) error {
|
|
||||||
k := m.cfg.Key(v)
|
|
||||||
m.mu.Lock()
|
|
||||||
m.data[k] = v
|
|
||||||
if m.cfg.Store != nil {
|
|
||||||
if err := m.cfg.Store.Upsert(ctx, v); err != nil {
|
|
||||||
m.mu.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
|
||||||
m.publish(opSet, k, v)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete removes the key locally, deletes it from the Store, then broadcasts.
|
|
||||||
func (m *SyncedMap[K, V]) Delete(ctx context.Context, k K) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
delete(m.data, k)
|
|
||||||
if m.cfg.Store != nil {
|
|
||||||
if err := m.cfg.Store.Delete(ctx, k); err != nil {
|
|
||||||
m.mu.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
|
||||||
var zero V
|
|
||||||
m.publish(opDelete, k, zero)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the value for k and whether it was present.
|
|
||||||
func (m *SyncedMap[K, V]) Get(k K) (V, bool) {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
v, ok := m.data[k]
|
|
||||||
return v, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns a snapshot slice of all values.
|
|
||||||
func (m *SyncedMap[K, V]) List() []V {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
out := make([]V, 0, len(m.data))
|
|
||||||
for _, v := range m.data {
|
|
||||||
out = append(out, v)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// Snapshot returns a copy of the underlying map.
|
|
||||||
func (m *SyncedMap[K, V]) Snapshot() map[K]V {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
out := make(map[K]V, len(m.data))
|
|
||||||
for k, v := range m.data {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// publish broadcasts a delta. Standalone (nil Nats) is a strict no-op.
|
|
||||||
func (m *SyncedMap[K, V]) publish(op string, k K, v V) {
|
|
||||||
if m.cfg.Nats == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := m.cfg.Nats.Publish(m.subject(), delta[K, V]{Op: op, Key: k, Value: v}); err != nil {
|
|
||||||
xlog.Warn("syncstate: failed to broadcast delta", "name", m.cfg.Name, "op", op, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply handles a peer's delta: memory-only update plus OnApply. It deliberately
|
|
||||||
// never writes the Store nor re-publishes - that is the echo-loop guard.
|
|
||||||
func (m *SyncedMap[K, V]) apply(d delta[K, V]) {
|
|
||||||
switch d.Op {
|
|
||||||
case opSet:
|
|
||||||
m.mu.Lock()
|
|
||||||
m.data[d.Key] = d.Value
|
|
||||||
m.mu.Unlock()
|
|
||||||
case opDelete:
|
|
||||||
m.mu.Lock()
|
|
||||||
delete(m.data, d.Key)
|
|
||||||
m.mu.Unlock()
|
|
||||||
default:
|
|
||||||
xlog.Warn("syncstate: ignoring delta with unknown op", "name", m.cfg.Name, "op", d.Op)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if m.cfg.OnApply != nil {
|
|
||||||
m.cfg.OnApply(d.Op, d.Key, d.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// hydrate replaces the whole map from the durable source: Store if present, else
|
|
||||||
// Loader. With neither, a late joiner starts empty and catches up via deltas
|
|
||||||
// (acceptable only for ephemeral state).
|
|
||||||
func (m *SyncedMap[K, V]) hydrate(ctx context.Context) error {
|
|
||||||
var (
|
|
||||||
vals []V
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
switch {
|
|
||||||
case m.cfg.Store != nil:
|
|
||||||
vals, err = m.cfg.Store.List(ctx)
|
|
||||||
case m.cfg.Loader != nil:
|
|
||||||
vals, err = m.cfg.Loader(ctx)
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.replaceAll(vals)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// replaceAll atomically swaps the map contents for the given values, keyed via
|
|
||||||
// cfg.Key.
|
|
||||||
func (m *SyncedMap[K, V]) replaceAll(vals []V) {
|
|
||||||
next := make(map[K]V, len(vals))
|
|
||||||
for _, v := range vals {
|
|
||||||
next[m.cfg.Key(v)] = v
|
|
||||||
}
|
|
||||||
m.mu.Lock()
|
|
||||||
m.data = next
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconcileLoop periodically re-hydrates to repair silent drift (missed deltas).
|
|
||||||
func (m *SyncedMap[K, V]) reconcileLoop() {
|
|
||||||
defer m.wg.Done()
|
|
||||||
t := time.NewTicker(m.cfg.Reconcile)
|
|
||||||
defer t.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-m.lifeCtx.Done():
|
|
||||||
return
|
|
||||||
case <-t.C:
|
|
||||||
if err := m.hydrate(m.lifeCtx); err != nil {
|
|
||||||
xlog.Warn("syncstate: reconcile re-hydrate failed", "name", m.cfg.Name, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package syncstate_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSyncstate(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Syncstate Suite")
|
|
||||||
}
|
|
||||||
@@ -1,291 +0,0 @@
|
|||||||
package syncstate_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
|
||||||
"github.com/mudler/LocalAI/core/services/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// job is a minimal JSON-serializable value stand-in for the real cross-replica
|
|
||||||
// records (finetune/quant/agent jobs) the component is built for.
|
|
||||||
type job struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func jobKey(j *job) string { return j.ID }
|
|
||||||
|
|
||||||
const stateName = "test.jobs"
|
|
||||||
|
|
||||||
func deltaSubject() string { return messaging.SubjectSyncStateDelta(stateName) }
|
|
||||||
|
|
||||||
// fakeStore is an in-memory Store that records call counts so specs can assert
|
|
||||||
// the write-through-vs-apply split (local writes hit the Store; applied deltas
|
|
||||||
// must not).
|
|
||||||
type fakeStore struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
data map[string]*job
|
|
||||||
upsertCalls int
|
|
||||||
deleteCalls int
|
|
||||||
listCalls int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeStore(seed ...*job) *fakeStore {
|
|
||||||
s := &fakeStore{data: map[string]*job{}}
|
|
||||||
for _, j := range seed {
|
|
||||||
s.data[j.ID] = j
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeStore) List(_ context.Context) ([]*job, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.listCalls++
|
|
||||||
out := make([]*job, 0, len(s.data))
|
|
||||||
for _, j := range s.data {
|
|
||||||
out = append(out, j)
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeStore) Upsert(_ context.Context, j *job) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.upsertCalls++
|
|
||||||
s.data[j.ID] = j
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeStore) Delete(_ context.Context, k string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.deleteCalls++
|
|
||||||
delete(s.data, k)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// add simulates a peer replica writing to the shared DB out-of-band (e.g. while
|
|
||||||
// this replica was partitioned), so a re-hydrate can be observed to pick it up.
|
|
||||||
func (s *fakeStore) add(j *job) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.data[j.ID] = j
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeStore) counts() (upsert, del, list int) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
return s.upsertCalls, s.deleteCalls, s.listCalls
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("SyncedMap", func() {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
Describe("cross-replica delta propagation", func() {
|
|
||||||
var (
|
|
||||||
bus *testutil.FakeBus
|
|
||||||
a, b *syncstate.SyncedMap[string, *job]
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
bus = testutil.NewFakeBus()
|
|
||||||
a = syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus})
|
|
||||||
b = syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus})
|
|
||||||
Expect(a.Start(ctx)).To(Succeed())
|
|
||||||
Expect(b.Start(ctx)).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(a.Close()).To(Succeed())
|
|
||||||
Expect(b.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("propagates a Set on A to B", func() {
|
|
||||||
Expect(a.Set(ctx, &job{ID: "1", Status: "running"})).To(Succeed())
|
|
||||||
|
|
||||||
got, ok := b.Get("1")
|
|
||||||
Expect(ok).To(BeTrue(), "replica B should see the value A just set")
|
|
||||||
Expect(got.Status).To(Equal("running"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("prunes a Delete on A from B", func() {
|
|
||||||
Expect(a.Set(ctx, &job{ID: "1", Status: "running"})).To(Succeed())
|
|
||||||
_, present := b.Get("1")
|
|
||||||
Expect(present).To(BeTrue(), "precondition: B must have the value before the delete")
|
|
||||||
|
|
||||||
Expect(a.Delete(ctx, "1")).To(Succeed())
|
|
||||||
|
|
||||||
_, ok := b.Get("1")
|
|
||||||
Expect(ok).To(BeFalse(), "a delete on A must remove the key from B")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("hydration", func() {
|
|
||||||
It("hydrates on Start from a preloaded Store", func() {
|
|
||||||
store := newFakeStore(&job{ID: "x", Status: "done"})
|
|
||||||
m := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Store: store})
|
|
||||||
Expect(m.Start(ctx)).To(Succeed())
|
|
||||||
|
|
||||||
got, ok := m.Get("x")
|
|
||||||
Expect(ok).To(BeTrue(), "Start must populate the map from the Store")
|
|
||||||
Expect(got.Status).To(Equal("done"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the Loader when Store is nil", func() {
|
|
||||||
m := syncstate.New(syncstate.Config[string, *job]{
|
|
||||||
Name: stateName,
|
|
||||||
Key: jobKey,
|
|
||||||
Loader: func(_ context.Context) ([]*job, error) {
|
|
||||||
return []*job{{ID: "l", Status: "loaded"}}, nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
Expect(m.Start(ctx)).To(Succeed())
|
|
||||||
|
|
||||||
got, ok := m.Get("l")
|
|
||||||
Expect(ok).To(BeTrue(), "Loader output must hydrate the map when there is no Store")
|
|
||||||
Expect(got.Status).To(Equal("loaded"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("echo-loop guard", func() {
|
|
||||||
It("applies its own broadcast once and does not re-publish", func() {
|
|
||||||
bus := testutil.NewFakeBus()
|
|
||||||
a := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus})
|
|
||||||
b := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus})
|
|
||||||
Expect(a.Start(ctx)).To(Succeed())
|
|
||||||
Expect(b.Start(ctx)).To(Succeed())
|
|
||||||
defer func() {
|
|
||||||
Expect(a.Close()).To(Succeed())
|
|
||||||
Expect(b.Close()).To(Succeed())
|
|
||||||
}()
|
|
||||||
|
|
||||||
Expect(a.Set(ctx, &job{ID: "e", Status: "running"})).To(Succeed())
|
|
||||||
|
|
||||||
// One local write must produce exactly one broadcast: A and B both
|
|
||||||
// receive it and apply to memory, but the apply path never re-publishes.
|
|
||||||
Expect(bus.PublishCount(deltaSubject())).To(Equal(1),
|
|
||||||
"the apply path must not re-broadcast, otherwise replicas storm")
|
|
||||||
Expect(a.List()).To(HaveLen(1), "A must not double-store its own echo")
|
|
||||||
_, ok := b.Get("e")
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("Store write-through vs apply", func() {
|
|
||||||
It("writes the Store on local Set/Delete but not on an applied delta", func() {
|
|
||||||
bus := testutil.NewFakeBus()
|
|
||||||
storeA := newFakeStore()
|
|
||||||
storeB := newFakeStore()
|
|
||||||
a := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus, Store: storeA})
|
|
||||||
b := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus, Store: storeB})
|
|
||||||
Expect(a.Start(ctx)).To(Succeed())
|
|
||||||
Expect(b.Start(ctx)).To(Succeed())
|
|
||||||
defer func() {
|
|
||||||
Expect(a.Close()).To(Succeed())
|
|
||||||
Expect(b.Close()).To(Succeed())
|
|
||||||
}()
|
|
||||||
|
|
||||||
Expect(a.Set(ctx, &job{ID: "w", Status: "running"})).To(Succeed())
|
|
||||||
|
|
||||||
upA, _, _ := storeA.counts()
|
|
||||||
upB, _, _ := storeB.counts()
|
|
||||||
Expect(upA).To(Equal(1), "local Set must write through to its own Store")
|
|
||||||
Expect(upB).To(Equal(0), "the apply path must never write the peer's Store")
|
|
||||||
|
|
||||||
Expect(a.Delete(ctx, "w")).To(Succeed())
|
|
||||||
_, delA, _ := storeA.counts()
|
|
||||||
_, delB, _ := storeB.counts()
|
|
||||||
Expect(delA).To(Equal(1), "local Delete must delete from its own Store")
|
|
||||||
Expect(delB).To(Equal(0), "the apply path must never delete from the peer's Store")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("OnApply hook", func() {
|
|
||||||
It("fires with the correct op and key on an applied delta", func() {
|
|
||||||
bus := testutil.NewFakeBus()
|
|
||||||
var (
|
|
||||||
mu sync.Mutex
|
|
||||||
ops []string
|
|
||||||
keys []string
|
|
||||||
)
|
|
||||||
a := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus})
|
|
||||||
b := syncstate.New(syncstate.Config[string, *job]{
|
|
||||||
Name: stateName, Key: jobKey, Nats: bus,
|
|
||||||
OnApply: func(op string, k string, _ *job) {
|
|
||||||
mu.Lock()
|
|
||||||
ops = append(ops, op)
|
|
||||||
keys = append(keys, k)
|
|
||||||
mu.Unlock()
|
|
||||||
},
|
|
||||||
})
|
|
||||||
Expect(a.Start(ctx)).To(Succeed())
|
|
||||||
Expect(b.Start(ctx)).To(Succeed())
|
|
||||||
defer func() {
|
|
||||||
Expect(a.Close()).To(Succeed())
|
|
||||||
Expect(b.Close()).To(Succeed())
|
|
||||||
}()
|
|
||||||
|
|
||||||
Expect(a.Set(ctx, &job{ID: "o", Status: "running"})).To(Succeed())
|
|
||||||
Expect(a.Delete(ctx, "o")).To(Succeed())
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
Expect(ops).To(Equal([]string{"set", "delete"}))
|
|
||||||
Expect(keys).To(Equal([]string{"o", "o"}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("standalone (nil Nats)", func() {
|
|
||||||
It("works in-memory with no panic and nothing to broadcast", func() {
|
|
||||||
m := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey})
|
|
||||||
Expect(m.Start(ctx)).To(Succeed())
|
|
||||||
defer func() { Expect(m.Close()).To(Succeed()) }()
|
|
||||||
|
|
||||||
Expect(func() {
|
|
||||||
Expect(m.Set(ctx, &job{ID: "s", Status: "running"})).To(Succeed())
|
|
||||||
}).ToNot(Panic())
|
|
||||||
|
|
||||||
got, ok := m.Get("s")
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
Expect(got.Status).To(Equal("running"))
|
|
||||||
Expect(m.List()).To(HaveLen(1))
|
|
||||||
Expect(m.Snapshot()).To(HaveKey("s"))
|
|
||||||
|
|
||||||
Expect(m.Delete(ctx, "s")).To(Succeed())
|
|
||||||
_, ok = m.Get("s")
|
|
||||||
Expect(ok).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Describe("reconnect re-hydrate", func() {
|
|
||||||
It("re-reads the source when the messaging client reconnects", func() {
|
|
||||||
bus := testutil.NewFakeBus()
|
|
||||||
store := newFakeStore(&job{ID: "init", Status: "running"})
|
|
||||||
m := syncstate.New(syncstate.Config[string, *job]{Name: stateName, Key: jobKey, Nats: bus, Store: store})
|
|
||||||
Expect(m.Start(ctx)).To(Succeed())
|
|
||||||
defer func() { Expect(m.Close()).To(Succeed()) }()
|
|
||||||
|
|
||||||
_, ok := m.Get("init")
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
|
|
||||||
// A peer writes to the shared DB while we are unaware (no delta seen).
|
|
||||||
store.add(&job{ID: "late", Status: "running"})
|
|
||||||
_, ok = m.Get("late")
|
|
||||||
Expect(ok).To(BeFalse(), "the new row should not appear before a re-hydrate")
|
|
||||||
|
|
||||||
bus.TriggerReconnect()
|
|
||||||
|
|
||||||
_, ok = m.Get("late")
|
|
||||||
Expect(ok).To(BeTrue(), "reconnect must re-hydrate from the source and pick up drift")
|
|
||||||
_, _, list := store.counts()
|
|
||||||
Expect(list).To(Equal(2), "exactly one Start hydrate plus one reconnect re-hydrate")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
package testutil
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/services/messaging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// FakeBus is an in-memory messaging.MessagingClient that delivers each published
|
|
||||||
// message synchronously to every registered subscriber whose subject filter
|
|
||||||
// matches, including NATS-style wildcard subjects (`*` matches exactly one
|
|
||||||
// token).
|
|
||||||
//
|
|
||||||
// Synchronous delivery keeps specs deterministic: the moment Publish returns,
|
|
||||||
// every matching subscriber's handler has already run, so the spec body can read
|
|
||||||
// the resulting state without polling. It is the shared test double for every
|
|
||||||
// cross-replica-sync adopter (gallery, syncstate, ...) so they exercise the same
|
|
||||||
// delivery semantics. It deliberately depends only on the standard library and
|
|
||||||
// the messaging package — no test framework — so it is importable anywhere.
|
|
||||||
type FakeBus struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
subs []fakeBusSub
|
|
||||||
// publishCounts records how many messages were published per subject, so a
|
|
||||||
// spec can assert the echo-loop guard (an applied delta must not re-publish).
|
|
||||||
publishCounts map[string]int
|
|
||||||
|
|
||||||
// reconnectCbs back the optional OnReconnect/TriggerReconnect pair, letting a
|
|
||||||
// spec exercise the component's reconnect re-hydrate path without a real
|
|
||||||
// NATS server.
|
|
||||||
reconnectCbs []func()
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeBusSub struct {
|
|
||||||
subject string
|
|
||||||
handler func([]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFakeBus returns a ready-to-use in-memory bus.
|
|
||||||
func NewFakeBus() *FakeBus {
|
|
||||||
return &FakeBus{publishCounts: map[string]int{}}
|
|
||||||
}
|
|
||||||
|
|
||||||
// subjectMatches reports whether a subscription filter matches a concrete
|
|
||||||
// subject, honoring the single-token `*` wildcard used by NATS.
|
|
||||||
func subjectMatches(filter, subject string) bool {
|
|
||||||
if filter == subject {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
fp := strings.Split(filter, ".")
|
|
||||||
sp := strings.Split(subject, ".")
|
|
||||||
if len(fp) != len(sp) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range fp {
|
|
||||||
if fp[i] == "*" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fp[i] != sp[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish marshals data as JSON and delivers it synchronously to every matching
|
|
||||||
// subscriber.
|
|
||||||
func (b *FakeBus) Publish(subject string, data any) error {
|
|
||||||
payload, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
b.mu.Lock()
|
|
||||||
b.publishCounts[subject]++
|
|
||||||
subs := append([]fakeBusSub(nil), b.subs...)
|
|
||||||
b.mu.Unlock()
|
|
||||||
for _, s := range subs {
|
|
||||||
if subjectMatches(s.subject, subject) {
|
|
||||||
s.handler(payload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublishCount returns how many messages were published on the exact subject.
|
|
||||||
func (b *FakeBus) PublishCount(subject string) int {
|
|
||||||
b.mu.Lock()
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
return b.publishCounts[subject]
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeBusSubscription struct {
|
|
||||||
bus *FakeBus
|
|
||||||
subRef fakeBusSub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeBusSubscription) Unsubscribe() error {
|
|
||||||
s.bus.mu.Lock()
|
|
||||||
defer s.bus.mu.Unlock()
|
|
||||||
for i, candidate := range s.bus.subs {
|
|
||||||
if candidate.subject == s.subRef.subject {
|
|
||||||
s.bus.subs = append(s.bus.subs[:i], s.bus.subs[i+1:]...)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *FakeBus) Subscribe(subject string, handler func([]byte)) (messaging.Subscription, error) {
|
|
||||||
sub := fakeBusSub{subject: subject, handler: handler}
|
|
||||||
b.mu.Lock()
|
|
||||||
b.subs = append(b.subs, sub)
|
|
||||||
b.mu.Unlock()
|
|
||||||
return &fakeBusSubscription{bus: b, subRef: sub}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *FakeBus) QueueSubscribe(subject, _ string, handler func([]byte)) (messaging.Subscription, error) {
|
|
||||||
return b.Subscribe(subject, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *FakeBus) QueueSubscribeReply(string, string, func([]byte, func([]byte))) (messaging.Subscription, error) {
|
|
||||||
return &fakeBusSubscription{bus: b}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *FakeBus) SubscribeReply(string, func([]byte, func([]byte))) (messaging.Subscription, error) {
|
|
||||||
return &fakeBusSubscription{bus: b}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *FakeBus) Request(string, []byte, time.Duration) ([]byte, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *FakeBus) IsConnected() bool { return true }
|
|
||||||
func (b *FakeBus) Close() {}
|
|
||||||
|
|
||||||
// OnReconnect mirrors *messaging.Client.OnReconnect so a spec can drive the
|
|
||||||
// component's reconnect re-hydrate path. The component detects this method via an
|
|
||||||
// optional interface assertion; implementing it here keeps the fake a faithful
|
|
||||||
// stand-in for the concrete client.
|
|
||||||
func (b *FakeBus) OnReconnect(cb func()) {
|
|
||||||
if cb == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.mu.Lock()
|
|
||||||
b.reconnectCbs = append(b.reconnectCbs, cb)
|
|
||||||
b.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TriggerReconnect runs every registered reconnect callback, simulating a NATS
|
|
||||||
// reconnect event.
|
|
||||||
func (b *FakeBus) TriggerReconnect() {
|
|
||||||
b.mu.Lock()
|
|
||||||
cbs := append([]func(){}, b.reconnectCbs...)
|
|
||||||
b.mu.Unlock()
|
|
||||||
for _, cb := range cbs {
|
|
||||||
cb()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,52 @@
|
|||||||
---
|
---
|
||||||
|
- name: "qwen-agentworld-35b-a3b"
|
||||||
|
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||||
|
urls:
|
||||||
|
- https://huggingface.co/unsloth/Qwen-AgentWorld-35B-A3B-GGUF
|
||||||
|
description: |
|
||||||
|
# Qwen-AgentWorld-35B-A3B
|
||||||
|
|
||||||
|
📑 Technical Report |
|
||||||
|
📖 Blog |
|
||||||
|
🤗 Hugging Face |
|
||||||
|
🤖 ModelScope |
|
||||||
|
💻 GitHub |
|
||||||
|
🖥️ Demo
|
||||||
|
|
||||||
|
> [!Note]
|
||||||
|
> This repository contains the model weights and configuration files for **Qwen-AgentWorld-35B-A3B**, a native language world model trained for agentic environment simulation.
|
||||||
|
>
|
||||||
|
> These artifacts are compatible with Hugging Face Transformers, vLLM, SGLang, etc.
|
||||||
|
|
||||||
|
**Qwen-AgentWorld** is the first language world model to cover seven agent interaction domains within a single model. It simulates agentic environments via long chain-of-thought reasoning, predicting the next environment state given an agent's action and interaction history. Trained through a three-stage pipeline — CPT injects environment knowledge, SFT activates next-state-prediction reasoning, RL sharpens simulation fidelity — Qwen-AgentWorld is a **native world model**: environment modeling is the training objective from the CPT stage onward, not a post-hoc add-on.
|
||||||
|
|
||||||
|
## Highlights
|
||||||
|
|
||||||
|
...
|
||||||
|
license: "apache-2.0"
|
||||||
|
tags:
|
||||||
|
- llm
|
||||||
|
- gguf
|
||||||
|
- qwen
|
||||||
|
icon: https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen-AgentWorld/logo.png
|
||||||
|
overrides:
|
||||||
|
backend: llama-cpp
|
||||||
|
function:
|
||||||
|
automatic_tool_parsing_fallback: true
|
||||||
|
grammar:
|
||||||
|
disable: true
|
||||||
|
known_usecases:
|
||||||
|
- chat
|
||||||
|
options:
|
||||||
|
- use_jinja:true
|
||||||
|
parameters:
|
||||||
|
model: llama-cpp/models/Qwen-AgentWorld-35B-A3B-GGUF/Qwen-AgentWorld-35B-A3B-UD-Q4_K_M.gguf
|
||||||
|
template:
|
||||||
|
use_tokenizer_template: true
|
||||||
|
files:
|
||||||
|
- filename: llama-cpp/models/Qwen-AgentWorld-35B-A3B-GGUF/Qwen-AgentWorld-35B-A3B-UD-Q4_K_M.gguf
|
||||||
|
sha256: e7a8eafdd8013443b6bcc4b6fb47b2d2025f772d359650b9ceb7d75971e22cad
|
||||||
|
uri: https://huggingface.co/unsloth/Qwen-AgentWorld-35B-A3B-GGUF/resolve/main/Qwen-AgentWorld-35B-A3B-UD-Q4_K_M.gguf
|
||||||
- name: "ornith-1.0-9b"
|
- name: "ornith-1.0-9b"
|
||||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||||
urls:
|
urls:
|
||||||
|
|||||||
@@ -17,9 +17,15 @@ rm -rf "${BACKEND_DIR}"/build-*
|
|||||||
# run.sh's final `exec $CURDIR/<binary>` is the contract for what gets launched;
|
# run.sh's final `exec $CURDIR/<binary>` is the contract for what gets launched;
|
||||||
# the binary is not always named after the backend (e.g. parakeet-cpp launches
|
# the binary is not always named after the backend (e.g. parakeet-cpp launches
|
||||||
# parakeet-cpp-grpc), so derive it from run.sh and fall back to ${BACKEND}.
|
# parakeet-cpp-grpc), so derive it from run.sh and fall back to ${BACKEND}.
|
||||||
|
#
|
||||||
|
# Only scan the `exec` line(s): many run.sh select a runtime CPU variant via
|
||||||
|
# unquoted `LIBRARY=$CURDIR/libgo<x>-avx512.so` lines, and a whole-file grep
|
||||||
|
# would pick the last of those (avx512, which Darwin never builds) instead of
|
||||||
|
# the binary — failing the check below for whisper/sam3-cpp/vibevoice-cpp/...
|
||||||
|
# Also tolerate the exec being quoted (`exec "$CURDIR"/<binary>`).
|
||||||
RUN_BINARY=""
|
RUN_BINARY=""
|
||||||
if [ -f "${BACKEND_DIR}/run.sh" ]; then
|
if [ -f "${BACKEND_DIR}/run.sh" ]; then
|
||||||
RUN_BINARY=$(grep -oE '\$CURDIR/[A-Za-z0-9._-]+' "${BACKEND_DIR}/run.sh" | grep -v 'ld\.so' | tail -1 | sed 's|\$CURDIR/||')
|
RUN_BINARY=$(grep -E '^[[:space:]]*exec[[:space:]]' "${BACKEND_DIR}/run.sh" | grep -oE '"?\$CURDIR"?/[A-Za-z0-9._-]+' | grep -v 'ld\.so' | tail -1 | sed -E 's|"?\$CURDIR"?/||')
|
||||||
fi
|
fi
|
||||||
RUN_BINARY="${RUN_BINARY:-${BACKEND}}"
|
RUN_BINARY="${RUN_BINARY:-${BACKEND}}"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user