Compare commits

...

25 Commits

Author SHA1 Message Date
Ettore Di Giacinto
3e8a54f4b6 chore(docs): improve
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-17 19:34:25 +01:00
Ettore Di Giacinto
18d11396cd chore(docs): improve documentation and split into sections bigger topics (#7292)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-17 18:39:21 +01:00
Ettore Di Giacinto
93cd688f40 chore: small ux enhancements (#7290)
* chore: improve chat attachments

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: display installed backends/models

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-17 17:09:42 +01:00
Ettore Di Giacinto
721c3f962b chore: scroll in thinking mode, better buttons placement (#7289)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-17 16:21:14 +01:00
LocalAI [bot]
fb834805db chore: ⬆️ Update ggml-org/llama.cpp to 80deff3648b93727422461c41c7279ef1dac7452 (#7287)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-17 07:51:08 +01:00
LocalAI [bot]
839aa7b42b feat(swagger): update swagger (#7286)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-17 07:49:06 +01:00
Ettore Di Giacinto
e963a45d66 feat(index): minor enhancements (#7288)
* feat(ui): add placeholder effect and select first model by default

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(ui): correctly bind focus to parent

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-16 21:23:18 +01:00
Mikhail Khludnev
c313b2c671 fix(reranker): tests and top_n check fix #7212 (#7284)
reranker tests and top_n check fix #7212

Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
2025-11-16 17:53:23 +01:00
Ettore Di Giacinto
137f16336e feat(ui): small refinements (#7285)
* feat(ui): show loaded models in the index

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(ui): re-organize navbar

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-16 17:50:13 +01:00
Ettore Di Giacinto
d7f9f3ac93 feat: add support to logitbias and logprobs (#7283)
* feat: add support to logprobs in results

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add support to logitbias

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-16 13:27:36 +01:00
Ettore Di Giacinto
cd7d384500 feat: restyle index (#7282)
* Move management to separate section

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Make index to redirect to chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Use logo in index

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* work out the wizard in the front-page

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-16 11:01:05 +01:00
LocalAI [bot]
d1a0dd10e6 chore: ⬆️ Update ggml-org/llama.cpp to 662192e1dcd224bc25759aadd0190577524c6a66 (#7277)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-16 08:41:12 +01:00
Ettore Di Giacinto
be8cf838c2 feat(importers): add transformers and vLLM (#7278)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-15 22:47:09 +01:00
LocalAI [bot]
3276d1cdaf feat(swagger): update swagger (#7276)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-15 21:50:30 +01:00
Ettore Di Giacinto
5e5f01badd chore(ui): import vendored libs (#7281)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-15 21:49:52 +01:00
Ettore Di Giacinto
6d0f646c37 chore: guide the user to import models (#7280)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-15 21:37:50 +01:00
Ettore Di Giacinto
99d31667f8 chore: do not use placeholder image (#7279)
Use font-awesome icons instead

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-15 21:31:58 +01:00
Ettore Di Giacinto
47b546afdc feat(mcp): add LocalAI endpoint to stream live results of the agent (#7274)
* feat(mcp): add LocalAI endpoint to stream live results of the agent

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* wip

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Refactoring

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* MCP UX integration

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Enhance UX

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Support also non-SSE

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-15 17:38:00 +01:00
LocalAI [bot]
a09d49da43 chore: ⬆️ Update ggml-org/llama.cpp to 9b17d74ab7d31cb7d15ee7eec1616c3d825a84c0 (#7273)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-15 00:05:39 +01:00
Ettore Di Giacinto
1cdcaf0152 feat: migrate to echo and enable cancellation of non-streaming requests (#7270)
* WIP: migrate to echo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-14 22:57:53 +01:00
Ettore Di Giacinto
03e9f4b140 fix: handle tool errors (#7271)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-14 17:23:56 +01:00
Ettore Di Giacinto
7129409bf6 chore(deps): bump llama.cpp to c4abcb2457217198efdd67d02675f5fddb7071c2 (#7266)
* chore(deps): bump llama.cpp to '92bb442ad999a0d52df0af2730cd861012e8ac5c'

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* DEBUG

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Bump

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test/debug

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Revert "DEBUG"

This reverts commit 2501ca3ff242076d623c13c86b3d6afcec426281.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-14 12:16:52 +01:00
LocalAI [bot]
d9e9ec6825 chore: ⬆️ Update ggml-org/whisper.cpp to d9b7613b34a343848af572cc14467fc5e82fc788 (#7268)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-13 23:05:06 +01:00
LocalAI [bot]
b82645d28d feat(swagger): update swagger (#7267)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2025-11-13 21:28:10 +00:00
Ettore Di Giacinto
735ca757fa feat(ui): allow to cancel ops (#7264)
* feat(ui): allow to cancel ops

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Improve progress text

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Cancel queued ops, don't show up message cancellation always

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: fixup displaying of total progress over multiple files

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-13 18:41:47 +01:00
153 changed files with 11679 additions and 3778 deletions

View File

@@ -156,6 +156,8 @@ message PredictOptions {
string CorrelationId = 47;
string Tools = 48; // JSON array of available tools/functions for tool calling
string ToolChoice = 49; // JSON string or object specifying tool choice behavior
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
}
// The response message containing the result
@@ -166,6 +168,7 @@ message Reply {
double timing_prompt_processing = 4;
double timing_token_generation = 5;
bytes audio = 6;
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
}
message GrammarTrigger {

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=7d019cff744b73084b15ca81ba9916f3efab1223
LLAMA_VERSION?=80deff3648b93727422461c41c7279ef1dac7452
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
}
}
// Extract logprobs and top_logprobs from proto and add to JSON data
// Following server.cpp pattern: logprobs maps to n_probs when provided
if (predict->logprobs() > 0) {
data["logprobs"] = predict->logprobs();
// Map logprobs to n_probs (following server.cpp line 369 pattern)
// n_probs will be set by params_from_json_cmpl if logprobs is provided
data["n_probs"] = predict->logprobs();
SRV_INF("Using logprobs: %d\n", predict->logprobs());
}
if (predict->toplogprobs() > 0) {
data["top_logprobs"] = predict->toplogprobs();
SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
}
// Extract logit_bias from proto and add to JSON data
if (!predict->logitbias().empty()) {
try {
// Parse logit_bias JSON string from proto
json logit_bias_json = json::parse(predict->logitbias());
// Add to data - llama.cpp server expects it as an object (map)
data["logit_bias"] = logit_bias_json;
SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
} catch (const json::parse_error& e) {
SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
}
}
data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();
@@ -568,6 +596,28 @@ public:
return Status::OK;
}
// Helper function to extract logprobs from JSON response
static json extract_logprobs_from_json(const json& res_json) {
json logprobs_json = json::object();
// Check for OAI-compatible format: choices[0].logprobs
if (res_json.contains("choices") && res_json["choices"].is_array() &&
res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
logprobs_json = res_json["choices"][0]["logprobs"];
}
// Check for non-OAI format: completion_probabilities
else if (res_json.contains("completion_probabilities")) {
// Convert completion_probabilities to OAI format
logprobs_json["content"] = res_json["completion_probabilities"];
}
// Check for direct logprobs field
else if (res_json.contains("logprobs")) {
logprobs_json = res_json["logprobs"];
}
return logprobs_json;
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
json data = parse_options(true, request, ctx_server);
@@ -579,7 +629,8 @@ public:
auto completion_id = gen_chatcmplid();
std::unordered_set<int> task_ids;
// need to store the reader as a pointer, so that it won't be destroyed when the handle returns
const auto rd = std::make_shared<server_response_reader>(ctx_server);
try {
std::vector<server_task> tasks;
@@ -620,6 +671,11 @@ public:
content_val = msg.content();
}
// If content is an object (e.g., from tool call failures), convert to string
if (content_val.is_object()) {
content_val = content_val.dump();
}
// If content is a string and this is the last user message with images/audio, combine them
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
json content_array = json::array();
@@ -871,18 +927,91 @@ public:
tasks.push_back(std::move(task));
}
task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(std::move(tasks));
rd->post_tasks(std::move(tasks));
} catch (const std::exception & e) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
}
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
// Get first result for error checking (following server.cpp pattern)
server_task_result_ptr first_result = rd->next([&context]() { return context->IsCancelled(); });
if (first_result == nullptr) {
// connection is closed
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
} else if (first_result->is_error()) {
json error_json = first_result->to_json();
backend::Reply reply;
reply.set_message(error_json.value("message", ""));
writer->Write(reply);
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
}
// Process first result
json first_res_json = first_result->to_json();
if (first_res_json.is_array()) {
for (const auto & res : first_res_json) {
std::string completion_text = res.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res.contains("timings")) {
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
} else {
std::string completion_text = first_res_json.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (first_res_json.contains("timings")) {
double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(first_res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
// Process subsequent results
while (rd->has_next()) {
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
return false;
break;
}
auto result = rd->next([&context]() { return context->IsCancelled(); });
if (result == nullptr) {
// connection is closed
break;
}
json res_json = result->to_json();
@@ -904,9 +1033,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
// Log Request Correlation Id
// Send the reply
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
} else {
@@ -926,24 +1059,16 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
// Send the reply
writer->Write(reply);
writer->Write(reply);
}
return true;
}, [&](const json & error_data) {
backend::Reply reply;
reply.set_message(error_data.value("content", ""));
writer->Write(reply);
return true;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
}
// Check if context was cancelled during processing
if (context->IsCancelled()) {
@@ -963,7 +1088,7 @@ public:
}
std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl;
auto completion_id = gen_chatcmplid();
std::unordered_set<int> task_ids;
const auto rd = std::make_shared<server_response_reader>(ctx_server);
try {
std::vector<server_task> tasks;
@@ -1004,6 +1129,11 @@ public:
content_val = msg.content();
}
// If content is an object (e.g., from tool call failures), convert to string
if (content_val.is_object()) {
content_val = content_val.dump();
}
// If content is a string and this is the last user message with images/audio, combine them
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
json content_array = json::array();
@@ -1261,9 +1391,7 @@ public:
tasks.push_back(std::move(task));
}
task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(std::move(tasks));
rd->post_tasks(std::move(tasks));
} catch (const std::exception & e) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
}
@@ -1271,51 +1399,71 @@ public:
std::cout << "[DEBUG] Waiting for results..." << std::endl;
// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Wait for all results
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
if (all_results.is_terminated) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
if (results.size() == 1) {
} else if (all_results.error) {
std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl;
reply->set_message(all_results.error->to_json().value("message", ""));
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred"));
} else {
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
if (all_results.results.size() == 1) {
// single result
reply->set_message(results[0]->to_json().value("content", ""));
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
json result_json = all_results.results[0]->to_json();
reply->set_message(result_json.value("content", ""));
int32_t tokens_predicted = results[0]->to_json().value("tokens_predicted", 0);
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
reply->set_tokens(tokens_predicted);
int32_t tokens_evaluated = results[0]->to_json().value("tokens_evaluated", 0);
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
reply->set_prompt_tokens(tokens_evaluated);
if (results[0]->to_json().contains("timings")) {
double timing_prompt_processing = results[0]->to_json().at("timings").value("prompt_ms", 0.0);
if (result_json.contains("timings")) {
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
reply->set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = results[0]->to_json().at("timings").value("predicted_ms", 0.0);
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
reply->set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(result_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply->set_logprobs(logprobs_str);
}
} else {
// multiple results (multitask)
json arr = json::array();
for (auto & res : results) {
arr.push_back(res->to_json().value("content", ""));
json logprobs_arr = json::array();
bool has_logprobs = false;
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
json res_json = res->to_json();
arr.push_back(res_json.value("content", ""));
// Extract logprobs for each result
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
has_logprobs = true;
logprobs_arr.push_back(logprobs_json);
} else {
logprobs_arr.push_back(json::object());
}
}
reply->set_message(arr);
// Set logprobs if any result has them
if (has_logprobs) {
std::string logprobs_str = logprobs_arr.dump();
reply->set_logprobs(logprobs_str);
}
}
}, [&](const json & error_data) {
std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
reply->set_message(error_data.value("content", ""));
}, [&context]() {
// Check if the gRPC context is cancelled
// This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
}
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
// Check if context was cancelled during processing
@@ -1352,9 +1500,7 @@ public:
int embd_normalize = 2; // default to Euclidean/L2 norm
// create and queue the task
json responses = json::array();
bool error = false;
std::unordered_set<int> task_ids;
const auto rd = std::make_shared<server_response_reader>(ctx_server);
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -1369,40 +1515,23 @@ public:
tasks.push_back(std::move(task));
}
task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(std::move(tasks));
rd->post_tasks(std::move(tasks));
}
// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Wait for all results
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
if (all_results.is_terminated) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
} else if (all_results.error) {
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
}
// get the result
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
}, [&](const json & error_data) {
error = true;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
if (error) {
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
// Collect responses
json responses = json::array();
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
@@ -1453,9 +1582,7 @@ public:
}
// Create and queue the task
json responses = json::array();
bool error = false;
std::unordered_set<int> task_ids;
const auto rd = std::make_shared<server_response_reader>(ctx_server);
{
std::vector<server_task> tasks;
std::vector<std::string> documents;
@@ -1473,40 +1600,23 @@ public:
tasks.push_back(std::move(task));
}
task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(std::move(tasks));
rd->post_tasks(std::move(tasks));
}
// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Wait for all results
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
if (all_results.is_terminated) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
} else if (all_results.error) {
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
}
// Get the results
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
}, [&](const json & error_data) {
error = true;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
if (error) {
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
// Collect responses
json responses = json::array();
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
// Sort responses by score in descending order
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=a1867e0dad0b21b35afa43fc815dae60c9a139d6
WHISPER_CPP_VERSION?=d9b7613b34a343848af572cc14467fc5e82fc788
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -62,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
if err := coreStartup.InstallModels(application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}

View File

@@ -2,6 +2,7 @@ package backend
import (
"context"
"encoding/json"
"regexp"
"slices"
"strings"
@@ -24,6 +25,7 @@ type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
}
type TokenUsage struct {
@@ -33,7 +35,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -45,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err
@@ -78,6 +80,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
opts.Audios = audios
opts.Tools = tools
opts.ToolChoice = toolChoice
if logprobs != nil {
opts.Logprobs = int32(*logprobs)
}
if topLogprobs != nil {
opts.TopLogprobs = int32(*topLogprobs)
}
if len(logitBias) > 0 {
// Serialize logit_bias map to JSON string for proto
logitBiasJSON, err := json.Marshal(logitBias)
if err == nil {
opts.LogitBias = string(logitBiasJSON)
}
}
tokenUsage := TokenUsage{}
@@ -109,6 +124,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}
ss := ""
var logprobs *schema.Logprobs
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -120,6 +136,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
// Parse logprobs from reply if present (collect from last chunk that has them)
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}
// Process complete runes and accumulate them
var completeRunes []byte
for len(partialRune) > 0 {
@@ -145,6 +169,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
return LLMResponse{
Response: ss,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -167,9 +192,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
response = c.TemplateConfig.ReplyPrefix + response
}
// Parse logprobs from reply if present
var logprobs *schema.Logprobs
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}
return LLMResponse{
Response: response,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
}
}

View File

@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
}
}
return &pb.PredictOptions{
pbOpts := &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
@@ -249,4 +249,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP),
}
// Logprobs and TopLogprobs are set by the caller if provided
return pbOpts
}

View File

@@ -1,6 +1,7 @@
package cli
import (
"context"
"encoding/json"
"fmt"
@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}

View File

@@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db)
signals.RegisterGracefulTerminationHandler(func() {
if err := appHTTP.Shutdown(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := appHTTP.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("error during shutdown")
}
})
return appHTTP.Listen(e.Address)
return appHTTP.Start(e.Address)
}

View File

@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}

View File

@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}
})
return appHTTP.Listen(r.Address)
return appHTTP.Start(r.Address)
}

View File

@@ -1,6 +1,7 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err

View File

@@ -9,6 +9,7 @@ import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/cogito"
"gopkg.in/yaml.v3"
)
@@ -668,3 +669,40 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
return true
}
// BuildCogitoOptions generates cogito options from the model configuration
// It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
cogitoOpts := []cogito.Option{
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts
cogito.WithForceReasoning(),
}
// Apply agent configuration options
if c.Agent.EnableReasoning {
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
}
if c.Agent.EnablePlanning {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
}
if c.Agent.EnableMCPPrompts {
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
}
if c.Agent.EnablePlanReEvaluator {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
}
if c.Agent.MaxIterations != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
}
if c.Agent.MaxAttempts != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
}
return cogitoOpts
}

View File

@@ -3,6 +3,7 @@
package gallery
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -69,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}
// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
@@ -109,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
return err
}
@@ -134,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
}
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
@@ -164,11 +165,17 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
}
} else {
uri := downloader.URI(config.URI)
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
success := false
// Try to download from mirrors
for _, mirror := range config.Mirrors {
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
// Check for cancellation before trying next mirror
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
break
}

View File

@@ -1,6 +1,7 @@
package gallery
import (
"context"
"encoding/json"
"os"
"path/filepath"
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
)
must(err)
sysDefault.GPUVendor = "" // force default selection
backs, err := ListSystemBackends(sysDefault)
backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
backs, err = ListSystemBackends(sysNvidia)
backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,7 +562,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
@@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

View File

@@ -1,6 +1,7 @@
package gallery
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
return config, nil
}
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
return config, err
}
return config, nil
}
func ReadConfigFile[T any](filePath string) (*T, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)

View File

@@ -10,9 +10,11 @@ import (
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)
var DefaultImporters = []Importer{
var defaultImporters = []Importer{
&LlamaCPPImporter{},
&MLXImporter{},
&VLLMImporter{},
&TransformersImporter{},
}
type Details struct {
@@ -52,7 +54,7 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model
Preferences: preferences,
}
for _, importer := range DefaultImporters {
for _, importer := range defaultImporters {
if importer.Match(details) {
modelConfig, err = importer.Import(details)
if err != nil {

View File

@@ -80,10 +80,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
mmprojQuantsList = strings.Split(mmprojQuants, ",")
}
embeddings, _ := preferencesMap["embeddings"].(string)
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Options: []string{"use_jinja:true"},
Backend: "llama-cpp",
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
@@ -95,6 +98,11 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
},
}
if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
trueV := true
modelConfig.Embeddings = &trueV
}
cfg := gallery.ModelConfig{
Name: name,
Description: description,

View File

@@ -5,20 +5,21 @@ import (
"fmt"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("LlamaCPPImporter", func() {
var importer *importers.LlamaCPPImporter
var importer *LlamaCPPImporter
BeforeEach(func() {
importer = &importers.LlamaCPPImporter{}
importer = &LlamaCPPImporter{}
})
Context("Match", func() {
It("should match when URI ends with .gguf", func() {
details := importers.Details{
details := Details{
URI: "https://example.com/model.gguf",
}
@@ -28,7 +29,7 @@ var _ = Describe("LlamaCPPImporter", func() {
It("should match when backend preference is llama-cpp", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := importers.Details{
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
@@ -38,7 +39,7 @@ var _ = Describe("LlamaCPPImporter", func() {
})
It("should not match when URI does not end with .gguf and no backend preference", func() {
details := importers.Details{
details := Details{
URI: "https://example.com/model.bin",
}
@@ -48,7 +49,7 @@ var _ = Describe("LlamaCPPImporter", func() {
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "mlx"}`)
details := importers.Details{
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
@@ -59,7 +60,7 @@ var _ = Describe("LlamaCPPImporter", func() {
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
details := Details{
URI: "https://example.com/model.gguf",
Preferences: preferences,
}
@@ -72,7 +73,7 @@ var _ = Describe("LlamaCPPImporter", func() {
Context("Import", func() {
It("should import model config with default name and description", func() {
details := importers.Details{
details := Details{
URI: "https://example.com/my-model.gguf",
}
@@ -89,7 +90,7 @@ var _ = Describe("LlamaCPPImporter", func() {
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := importers.Details{
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
@@ -106,7 +107,7 @@ var _ = Describe("LlamaCPPImporter", func() {
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}

View File

@@ -0,0 +1,110 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &TransformersImporter{}
type TransformersImporter struct{}
func (i *TransformersImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "transformers" {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}
return false
}
func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "transformers"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelType, ok := preferencesMap["type"].(string)
if !ok {
modelType = "AutoModelForCausalLM"
}
quantization, ok := preferencesMap["quantization"].(string)
if !ok {
quantization = ""
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
modelConfig.ModelType = modelType
modelConfig.Quantization = quantization
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -0,0 +1,219 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TransformersImporter", func() {
var importer *TransformersImporter
BeforeEach(func() {
importer = &TransformersImporter{}
})
Context("Match", func() {
It("should match when backend preference is transformers", func() {
preferences := json.RawMessage(`{"backend": "transformers"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer_config.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer_config.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI has no tokenizer files and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
It("should use custom model type from preferences", func() {
preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
})
It("should use default model type when not specified", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "transformers"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
})
It("should use quantization from preferences", func() {
preferences := json.RawMessage(`{"quantization": "int8"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/test/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
It("should include use_tokenizer_template in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
})
It("should include known_usecases in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
})
})
})

View File

@@ -0,0 +1,98 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &VLLMImporter{}
type VLLMImporter struct{}
func (i *VLLMImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "vllm" {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}
return false
}
func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "vllm"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -0,0 +1,181 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("VLLMImporter", func() {
var importer *VLLMImporter
BeforeEach(func() {
importer = &VLLMImporter{}
})
Context("Match", func() {
It("should match when backend preference is vllm", func() {
preferences := json.RawMessage(`{"backend": "vllm"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer_config.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer_config.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI has no tokenizer files and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "vllm"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/test/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
It("should include use_tokenizer_template in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
})
It("should include known_usecases in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
})
})
})

View File

@@ -1,6 +1,7 @@
package gallery
import (
"context"
"errors"
"fmt"
"os"
@@ -72,6 +73,7 @@ type PromptTemplate struct {
// Installs a model from the gallery
func InstallModelFromGallery(
ctx context.Context,
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
return err
}
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
return applyModel(model)
}
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
// Download files and verify their SHA
for i, file := range config.Files {
// Check for cancellation before each file
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
}
}
uri := downloader.URI(file.URI)
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return nil, err
}
}

View File

@@ -1,6 +1,7 @@
package gallery_test
import (
"context"
"errors"
"os"
"path/filepath"
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})
})

View File

@@ -4,30 +4,23 @@ import (
"embed"
"errors"
"fmt"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/websocket/v2"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
// swagger handler
"github.com/rs/zerolog/log"
)
@@ -49,85 +42,85 @@ var embedDirStatic embed.FS
// @in header
// @name Authorization
func API(application *application.Application) (*fiber.App, error) {
func API(application *application.Application) (*echo.Echo, error) {
e := echo.New()
fiberCfg := fiber.Config{
Views: renderEngine(),
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
// Set body limit
if application.ApplicationConfig().UploadLimitMB > 0 {
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
}
// Set error handler
if !application.ApplicationConfig().OpaqueErrors {
// Normally, return errors as JSON responses
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
// Handle 404 errors with HTML rendering when appropriate
if code == http.StatusNotFound {
notFoundHandler(c)
return
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
c.JSON(code, schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
})
}
} else {
// If OpaqueErrors are required, replace everything with a blank 500.
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
return ctx.Status(500).SendString("")
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
c.NoContent(code)
}
}
router := fiber.New(fiberCfg)
// Set renderer
e.Renderer = renderEngine()
router.Use(middleware.StripPathPrefix())
// Hide banner
e.HideBanner = true
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
e.Pre(httpMiddleware.StripPathPrefix())
if application.ApplicationConfig().MachineTag != "" {
router.Use(func(c *fiber.Ctx) error {
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return c.Next()
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return next(c)
}
})
}
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
// Returns true if the client requested upgrade to the WebSocket protocol
return c.Next()
// Custom logger middleware using zerolog
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
res := c.Response()
start := log.Logger.Info()
err := next(c)
start.
Str("method", req.Method).
Str("path", req.URL.Path).
Int("status", res.Status).
Msg("HTTP request")
return err
}
return nil
})
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
}
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
return nil
})
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
logger := log.Logger
router.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))
// Default middleware config
// Recover middleware
if !application.ApplicationConfig().Debug {
router.Use(recover.New())
e.Use(middleware.Recover())
}
// Metrics middleware
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
@@ -135,34 +128,40 @@ func API(application *application.Application) (*fiber.App, error) {
}
if metricsService != nil {
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
router.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
e.Server.RegisterOnShutdown(func() {
metricsService.Shutdown()
})
}
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(router)
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil || kaConfig == nil {
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(e)
// Get key auth middleware
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
httpFS := http.FS(embedDirStatic)
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
router.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
}
e.StaticFS("/static", staticFS)
// Generated content directories
if application.ApplicationConfig().GeneratedContentDir != "" {
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
@@ -173,51 +172,53 @@ func API(application *application.Application) (*fiber.App, error) {
os.MkdirAll(imagePath, 0750)
os.MkdirAll(videoPath, 0750)
router.Static("/generated-audio", audioPath)
router.Static("/generated-images", imagePath)
router.Static("/generated-videos", videoPath)
e.Static("/generated-audio", audioPath)
e.Static("/generated-images", imagePath)
e.Static("/generated-videos", videoPath)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
router.Use(v2keyauth.New(*kaConfig))
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
e.Use(keyAuthMiddleware)
// CORS middleware
if application.ApplicationConfig().CORS {
var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig().CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
corsConfig := middleware.CORSConfig{}
if application.ApplicationConfig().CORSAllowOrigins != "" {
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
}
router.Use(c)
e.Use(middleware.CORSWithConfig(corsConfig))
}
// CSRF middleware
if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
router.Use(csrf.New())
e.Use(middleware.CSRF())
}
requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService())
}
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
return router, nil
// Log startup message
e.Server.RegisterOnShutdown(func() {
log.Info().Msg("LocalAI API server shutting down")
})
return e, nil
}

View File

@@ -10,13 +10,14 @@ import (
"os"
"path/filepath"
"runtime"
"time"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
@@ -25,6 +26,7 @@ import (
"gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo"
"github.com/rs/zerolog/log"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
@@ -85,7 +87,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
@@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
var _ = Describe("API test", func() {
var app *fiber.App
var app *echo.Echo
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
@@ -339,7 +341,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -358,7 +364,9 @@ var _ = Describe("API test", func() {
AfterEach(func(sc SpecContext) {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -547,7 +555,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -566,7 +578,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -755,7 +769,11 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -773,7 +791,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
})
@@ -796,6 +816,83 @@ var _ = Describe("API test", func() {
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("returns logprobs in chat completions when requested", func() {
topLogprobsVal := 3
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
LogProbs: true,
TopLogProbs: topLogprobsVal,
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// Verify logprobs are present and have correct structure
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
foundatLeastToken := ""
foundAtLeastBytes := []byte{}
foundAtLeastTopLogprobBytes := []byte{}
foundatLeastTopLogprob := ""
// Verify logprobs content structure matches OpenAI format
for _, logprobContent := range response.Choices[0].LogProbs.Content {
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
if len(logprobContent.Bytes) > 0 {
foundAtLeastBytes = logprobContent.Bytes
}
if len(logprobContent.Token) > 0 {
foundatLeastToken = logprobContent.Token
}
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
// If top_logprobs is requested, verify top_logprobs array respects the limit
if len(logprobContent.TopLogProbs) > 0 {
// Should respect top_logprobs limit (3 in this test)
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
for _, topLogprob := range logprobContent.TopLogProbs {
if len(topLogprob.Bytes) > 0 {
foundAtLeastTopLogprobBytes = topLogprob.Bytes
}
if len(topLogprob.Token) > 0 {
foundatLeastTopLogprob = topLogprob.Token
}
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
}
}
}
Expect(foundAtLeastBytes).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
Expect(foundatLeastToken).ToNot(BeEmpty())
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
})
It("applies logit_bias to chat completions when requested", func() {
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
logitBias := map[string]int{
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
}
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
LogitBias: logitBias,
})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// If logit_bias is applied, the response should be generated successfully
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
})
It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred())
@@ -1006,7 +1103,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -1022,7 +1123,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
})

View File

@@ -1,7 +1,9 @@
package elevenlabs
import (
"github.com/gofiber/fiber/v2"
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -15,17 +17,17 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
@@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
if err != nil {
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}

View File

@@ -1,13 +1,14 @@
package elevenlabs
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -17,19 +18,19 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
voiceID := c.Params("voice-id")
voiceID := c.Param("voice-id")
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
@@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
if err != nil {
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}

View File

@@ -2,28 +2,32 @@ package explorer
import (
"encoding/base64"
"net/http"
"sort"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/internal"
)
func Dashboard() func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
summary := fiber.Map{
func Dashboard() echo.HandlerFunc {
return func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
}
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
return c.JSON(http.StatusOK, summary)
} else {
// Render index
return c.Render("views/explorer", summary)
return c.Render(http.StatusOK, "views/explorer", summary)
}
}
}
@@ -39,8 +43,8 @@ type Network struct {
Token string `json:"token"`
}
func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
results := []Network{}
for _, token := range db.TokenList() {
networkData, exists := db.Get(token) // get the token data
@@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return len(results[i].Clusters) > len(results[j].Clusters)
})
return c.JSON(results)
return c.JSON(http.StatusOK, results)
}
}
func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
request := new(AddNetworkRequest)
if err := c.BodyParser(request); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
if err := c.Bind(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
}
if request.Token == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
}
if request.Name == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
}
if request.Description == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
}
// TODO: check if token is valid, otherwise reject
// try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
}
if _, exists := db.Get(request.Token); exists {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
}
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
}
return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
}
}

View File

@@ -1,11 +1,12 @@
package jina
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@@ -17,24 +18,36 @@ import (
// @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
var requestTopN int32
docs := int32(len(input.Documents))
if input.TopN == nil { // omit top_n to get all
requestTopN = docs
} else {
requestTopN = int32(*input.TopN)
if requestTopN < 1 {
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
}
if requestTopN > docs { // make it more obvious for backends
requestTopN = docs
}
}
request := &proto.RerankRequest{
Query: input.Query,
TopN: int32(input.TopN),
TopN: requestTopN,
Documents: input.Documents,
}
@@ -58,6 +71,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
return c.Status(fiber.StatusOK).JSON(response)
return c.JSON(http.StatusOK, response)
}
}

View File

@@ -4,11 +4,11 @@ import (
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.backendApplier.GetStatus(c.Params("uuid"))
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
status := mgs.backendApplier.GetStatus(c.Param("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
return c.JSON(200, status)
}
}
@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.backendApplier.GetAllStatus())
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
return c.JSON(200, mgs.backendApplier.GetAllStatus())
}
}
@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
// @Param request body GalleryBackend true "query params"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/apply [post]
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
input := new(GalleryBackend)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -82,7 +82,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
Galleries: mgs.galleries,
}
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
@@ -91,9 +91,9 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
// @Param name path string true "Backend name"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/delete/{name} [post]
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendName := c.Params("name")
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
backendName := c.Param("name")
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
Delete: true,
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
return err
}
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
// @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
return err
}
return c.JSON(backends.GetAll())
return c.JSON(200, backends.GetAll())
}
}
@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
// @Success 200 {object} []config.Gallery "Response"
// @Router /backends/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
return c.Blob(200, "application/json", dat)
}
}
@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
// @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
if err != nil {
return err
}
return c.JSON(backends)
return c.JSON(200, backends)
}
}

View File

@@ -1,45 +1,45 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}
// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}
package localai
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(200, resp)
}
}
// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,17 +16,17 @@ import (
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
}
}
return c.JSON(response)
return c.JSON(200, response)
}
}

View File

@@ -2,11 +2,13 @@ package localai
import (
"fmt"
"io"
"net/http"
"os"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/utils"
@@ -14,15 +16,15 @@ import (
)
// GetEditModelPage renders the edit model page with current configuration
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
if modelName == "" {
response := ModelResponse{
Success: false,
Error: "Model name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
modelConfig, exists := cl.GetModelConfig(modelName)
@@ -31,7 +33,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Model configuration not found",
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
modelConfigFile := modelConfig.GetModelConfigFile()
@@ -40,7 +42,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Model configuration file not found",
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
configData, err := os.ReadFile(modelConfigFile)
if err != nil {
@@ -48,7 +50,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Failed to read configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Render the edit page with the current configuration
@@ -69,20 +71,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Version: internal.PrintableVersion(),
}
return c.Render("views/model-editor", templateData)
return c.Render(http.StatusOK, "views/model-editor", templateData)
}
}
// EditModelEndpoint handles updating existing model configurations
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
if modelName == "" {
response := ModelResponse{
Success: false,
Error: "Model name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
modelConfig, exists := cl.GetModelConfig(modelName)
@@ -91,17 +93,24 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Existing model configuration not found",
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
// Get the raw body
body := c.Body()
body, err := io.ReadAll(c.Request().Body)
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
if len(body) == 0 {
response := ModelResponse{
Success: false,
Error: "Request body is empty",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Check content to see if it's a valid model config
@@ -113,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Validate required fields
@@ -122,7 +131,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Validate the configuration
@@ -132,7 +141,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Error: "Validation failed",
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Load the existing configuration
@@ -142,7 +151,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Model configuration not trusted: " + err.Error(),
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
// Write new content to file
@@ -151,7 +160,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Reload configurations
@@ -160,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Preload the model
@@ -169,7 +178,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to preload model: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Return success response
@@ -179,20 +188,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Filename: configPath,
Config: req,
}
return c.JSON(response)
return c.JSON(200, response)
}
}
// ReloadModelsEndpoint handles reloading model configurations from disk
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Preload the models
@@ -201,7 +210,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: false,
Error: "Failed to preload models: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Return success response
@@ -209,6 +218,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: true,
Message: "Model configurations reloaded successfully",
}
return c.Status(fiber.StatusOK).JSON(response)
return c.JSON(http.StatusOK, response)
}
}

View File

@@ -2,12 +2,14 @@ package localai_test
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/pkg/system"
@@ -15,6 +17,14 @@ import (
. "github.com/onsi/gomega"
)
// testRenderer is a simple renderer for tests that returns JSON
type testRenderer struct{}
func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
// For tests, just return the data as JSON
return json.NewEncoder(w).Encode(data)
}
var _ = Describe("Edit Model test", func() {
var tempDir string
@@ -40,33 +50,35 @@ var _ = Describe("Edit Model test", func() {
//modelLoader := model.NewModelLoader(systemState, true)
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
// Define Fiber app.
app := fiber.New()
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
// Define Echo app and register all routes upfront
app := echo.New()
// Set up a simple renderer for the test
app.Renderer = &testRenderer{}
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
req := httptest.NewRequest("PUT", "/import-model", requestBody)
resp, err := app.Test(req, 5000)
Expect(err).ToNot(HaveOccurred())
req := httptest.NewRequest("POST", "/import-model", requestBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
body, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
body, err := io.ReadAll(rec.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
Expect(rec.Code).To(Equal(http.StatusOK))
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
req = httptest.NewRequest("GET", "/edit-model/foo", nil)
rec = httptest.NewRecorder()
app.ServeHTTP(rec, req)
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
resp, _ = app.Test(req, 1)
body, err = io.ReadAll(resp.Body)
defer resp.Body.Close()
body, err = io.ReadAll(rec.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
// The response contains the model configuration with backend field
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
Expect(rec.Code).To(Equal(http.StatusOK))
})
})
})

View File

@@ -1,160 +1,160 @@
package localai
import (
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
type ModelGalleryEndpointService struct {
galleries []config.Gallery
backendGalleries []config.Gallery
modelPath string
galleryApplier *services.GalleryService
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
backendGalleries: backendGalleries,
modelPath: systemState.Model.ModelsPath,
galleryApplier: galleryApplier,
}
}
// GetOpStatusEndpoint returns the job status
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
// GetAllStatusEndpoint returns all the jobs status progress
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.GetAllStatus())
}
}
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Req: input.GalleryModel,
ID: uuid.String(),
GalleryElementName: input.ID,
Galleries: mgs.galleries,
BackendGalleries: mgs.backendGalleries,
}
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
// @Summary delete models to LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Delete: true,
GalleryElementName: modelName,
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return err
}
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
m := []gallery.Metadata{}
for _, mm := range models {
m = append(m, mm.Metadata)
}
log.Debug().Msgf("Models %#v", m)
dat, err := json.Marshal(m)
if err != nil {
return fmt.Errorf("could not marshal models: %w", err)
}
return c.Send(dat)
}
}
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
package localai
import (
"encoding/json"
"fmt"
"github.com/labstack/echo/v4"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
type ModelGalleryEndpointService struct {
galleries []config.Gallery
backendGalleries []config.Gallery
modelPath string
galleryApplier *services.GalleryService
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
backendGalleries: backendGalleries,
modelPath: systemState.Model.ModelsPath,
galleryApplier: galleryApplier,
}
}
// GetOpStatusEndpoint returns the job status
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
status := mgs.galleryApplier.GetStatus(c.Param("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(200, status)
}
}
// GetAllStatusEndpoint returns all the jobs status progress
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
return c.JSON(200, mgs.galleryApplier.GetAllStatus())
}
}
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Req: input.GalleryModel,
ID: uuid.String(),
GalleryElementName: input.ID,
Galleries: mgs.galleries,
BackendGalleries: mgs.backendGalleries,
}
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
// @Summary delete models to LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Delete: true,
GalleryElementName: modelName,
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return err
}
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
m := []gallery.Metadata{}
for _, mm := range models {
m = append(m, mm.Metadata)
}
log.Debug().Msgf("Models %#v", m)
dat, err := json.Marshal(m)
if err != nil {
return fmt.Errorf("could not marshal models: %w", err)
}
return c.Blob(200, "application/json", dat)
}
}
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Blob(200, "application/json", dat)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -21,17 +21,17 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.TokenMetricsRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
if err != nil {
return err
}
return c.JSON(response)
return c.JSON(200, response)
}
}

View File

@@ -3,16 +3,18 @@ package localai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/gallery/importers"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils"
@@ -21,12 +23,12 @@ import (
)
// ImportModelURIEndpoint handles creating new model configurations from a URI
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) fiber.Handler {
return func(c *fiber.Ctx) error {
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.ImportModelRequest)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -61,7 +63,7 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
BackendGalleries: appConfig.BackendGalleries,
}
return c.JSON(schema.GalleryResponse{
return c.JSON(200, schema.GalleryResponse{
ID: uuid.String(),
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
})
@@ -69,22 +71,28 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
}
// ImportModelEndpoint handles creating new model configurations
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// Get the raw body
body := c.Body()
body, err := io.ReadAll(c.Request().Body)
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
if len(body) == 0 {
response := ModelResponse{
Success: false,
Error: "Request body is empty",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Check content type to determine how to parse
contentType := string(c.Context().Request.Header.ContentType())
contentType := c.Request().Header.Get("Content-Type")
var modelConfig config.ModelConfig
var err error
if strings.Contains(contentType, "application/json") {
// Parse JSON
@@ -93,7 +101,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
// Parse YAML
@@ -102,18 +110,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
} else {
// Try to auto-detect format
if strings.TrimSpace(string(body))[0] == '{' {
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
// Looks like JSON
if err := json.Unmarshal(body, &modelConfig); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
} else {
// Assume YAML
@@ -122,7 +130,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
}
}
@@ -133,7 +141,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Set defaults
@@ -145,7 +153,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Invalid configuration",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Create the configuration file
@@ -155,7 +163,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Model path not trusted: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Marshal to YAML for storage
@@ -165,7 +173,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to marshal configuration: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Write the file
@@ -174,7 +182,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
@@ -182,7 +190,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Preload the model
@@ -191,7 +199,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to preload model: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Return success response
response := ModelResponse{
@@ -199,6 +207,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Message: "Model configuration created successfully",
Filename: filepath.Base(configPath),
}
return c.JSON(response)
return c.JSON(200, response)
}
}

View File

@@ -0,0 +1,323 @@
package localai
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/cogito"
"github.com/rs/zerolog/log"
)
// MCP SSE Event Types
type MCPReasoningEvent struct {
Type string `json:"type"`
Content string `json:"content"`
}
type MCPToolCallEvent struct {
Type string `json:"type"`
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
Reasoning string `json:"reasoning"`
}
type MCPToolResultEvent struct {
Type string `json:"type"`
Name string `json:"name"`
Result string `json:"result"`
}
type MCPStatusEvent struct {
Type string `json:"type"`
Message string `json:"message"`
}
type MCPAssistantEvent struct {
Type string `json:"type"`
Content string `json:"content"`
}
type MCPErrorEvent struct {
Type string `json:"type"`
Message string `json:"message"`
}
// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
// @Summary Stream MCP chat completions with reasoning, tool calls, and results
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/mcp/chat/completions [post]
func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
created := int(time.Now().Unix())
// Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
}
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
}
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
return fmt.Errorf("no MCP servers configured")
}
// Get MCP config from model config
remote, stdio, err := config.MCP.MCPConfigFromYAML()
if err != nil {
return fmt.Errorf("failed to get MCP config: %w", err)
}
// Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil {
return fmt.Errorf("failed to get MCP sessions: %w", err)
}
if len(sessions) == 0 {
return fmt.Errorf("no working MCP servers found")
}
// Build fragment from messages
fragment := cogito.NewEmptyFragment()
for _, message := range input.Messages {
fragment = fragment.AddMessage(message.Role, message.StringContent)
}
port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
apiKey := ""
if len(appConfig.ApiKeys) > 0 {
apiKey = appConfig.ApiKeys[0]
}
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
// we satisfy to just call internally ComputeChoices
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
// Build cogito options using the consolidated method
cogitoOpts := config.BuildCogitoOptions()
cogitoOpts = append(
cogitoOpts,
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
)
// Check if streaming is requested
toStream := input.Stream
if !toStream {
// Non-streaming mode: execute synchronously and return JSON response
cogitoOpts = append(
cogitoOpts,
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
cogito.WithReasoningCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
}),
)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
cogitoOpts...,
)
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
return err
}
f, err = defaultLLM.Ask(ctxWithCancellation, f)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
Object: "chat.completion",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
}
// Streaming mode: use SSE
// Set up SSE headers
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
// Create channel for streaming events
events := make(chan interface{})
ended := make(chan error, 1)
// Set up callbacks for streaming
statusCallback := func(s string) {
events <- MCPStatusEvent{
Type: "status",
Message: s,
}
}
reasoningCallback := func(s string) {
events <- MCPReasoningEvent{
Type: "reasoning",
Content: s,
}
}
toolCallCallback := func(t *cogito.ToolChoice) bool {
events <- MCPToolCallEvent{
Type: "tool_call",
Name: t.Name,
Arguments: t.Arguments,
Reasoning: t.Reasoning,
}
return true
}
toolCallResultCallback := func(t cogito.ToolStatus) {
events <- MCPToolResultEvent{
Type: "tool_result",
Name: t.Name,
Result: t.Result,
}
}
cogitoOpts = append(cogitoOpts,
cogito.WithStatusCallback(statusCallback),
cogito.WithReasoningCallback(reasoningCallback),
cogito.WithToolCallBack(toolCallCallback),
cogito.WithToolCallResultCallback(toolCallResultCallback),
)
// Execute tools in a goroutine
go func() {
defer close(events)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
cogitoOpts...,
)
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
events <- MCPErrorEvent{
Type: "error",
Message: fmt.Sprintf("Failed to execute tools: %v", err),
}
ended <- err
return
}
// Get final response
f, err = defaultLLM.Ask(ctxWithCancellation, f)
if err != nil {
events <- MCPErrorEvent{
Type: "error",
Message: fmt.Sprintf("Failed to get response: %v", err),
}
ended <- err
return
}
// Stream final assistant response
content := f.LastMessage().Content
events <- MCPAssistantEvent{
Type: "assistant",
Content: content,
}
ended <- nil
}()
// Stream events to client
LOOP:
for {
select {
case <-ctx.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
cancel()
break LOOP
case event := <-events:
if event == nil {
// Channel closed
break LOOP
}
eventData, err := json.Marshal(event)
if err != nil {
log.Debug().Msgf("Failed to marshal event: %v", err)
continue
}
log.Debug().Msgf("Sending event: %s", string(eventData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
if err != nil {
log.Debug().Msgf("Sending event failed: %v", err)
cancel()
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
// Send done signal
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
errorEvent := MCPErrorEvent{
Type: "error",
Message: err.Error(),
}
errorData, marshalErr := json.Marshal(errorEvent)
if marshalErr != nil {
fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
}
log.Debug().Msgf("Stream ended")
return nil
}
}

View File

@@ -1,46 +1,47 @@
package localai
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metricsService *services.LocalAIMetricsService
}
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
cfg := apiMiddlewareConfig{
metricsService: metrics,
Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics"
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err
}
}
package localai
import (
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() echo.HandlerFunc {
return echo.WrapHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c echo.Context) bool
metricsService *services.LocalAIMetricsService
}
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
cfg := apiMiddlewareConfig{
metricsService: metrics,
Filter: func(c echo.Context) bool {
return c.Path() == "/metrics"
},
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if cfg.Filter != nil && cfg.Filter(c) {
return next(c)
}
path := c.Path()
method := c.Request().Method
start := time.Now()
err := next(c)
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err
}
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema"
@@ -11,10 +11,10 @@ import (
// @Summary Returns available P2P nodes
// @Success 200 {object} []schema.P2PNodesResponse "Response"
// @Router /api/p2p [get]
func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
// Render index
return func(c *fiber.Ctx) error {
return c.JSON(schema.P2PNodesResponse{
return func(c echo.Context) error {
return c.JSON(200, schema.P2PNodesResponse{
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
})
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
// @Summary Show the P2P token
// @Success 200 {string} string "Response"
// @Router /api/p2p/token [get]
func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
@@ -9,11 +9,11 @@ import (
"github.com/mudler/LocalAI/pkg/store"
)
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresSet)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
vals[i] = []byte(v)
}
err = store.SetCols(c.Context(), sb, input.Keys, vals)
err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
if err != nil {
return err
}
return c.Send(nil)
return c.NoContent(200)
}
}
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresDelete)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
}
defer sl.Close()
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
return err
}
return c.Send(nil)
return c.NoContent(200)
}
}
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresGet)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
}
defer sl.Close()
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
if err != nil {
return err
}
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
res.Values[i] = string(v)
}
return c.JSON(res)
return c.JSON(200, res)
}
}
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresFind)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
}
defer sl.Close()
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
if err != nil {
return err
}
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
res.Values[i] = string(v)
}
return c.JSON(res)
return c.JSON(200, res)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
@@ -11,8 +11,8 @@ import (
// @Summary Show the LocalAI instance information
// @Success 200 {object} schema.SystemInformationResponse "Response"
// @Router /system [get]
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
availableBackends := []string{}
loadedModels := ml.ListLoadedModels()
for b := range appConfig.ExternalGRPCBackends {
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
for _, m := range loadedModels {
sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
}
return c.JSON(
return c.JSON(200,
schema.SystemInformationResponse{
Backends: availableBackends,
Models: sysmodels,

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,22 +14,22 @@ import (
// @Param request body schema.TokenizeRequest true "Request"
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
return ctx.JSON(tokenResponse)
return c.JSON(200, tokenResponse)
}
}

View File

@@ -1,12 +1,14 @@
package localai
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
@@ -22,16 +24,16 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
@@ -59,6 +61,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,26 +16,26 @@ import (
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)
if err != nil {
return err
}
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -7,19 +7,20 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -64,18 +65,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.VideoRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
if !ok || input.Model == "" {
log.Error().Msg("Video Endpoint - Invalid Input")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Video Endpoint - Invalid Config")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
src := ""
@@ -164,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err
}
baseURL := c.BaseURL()
baseURL := middleware.BaseURL(c)
fn, err := backend.VideoGeneration(
height,
@@ -201,7 +202,10 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-videos/" + base
item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
if err != nil {
return err
}
}
id := uuid.New().String()
@@ -216,6 +220,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -1,18 +1,20 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
galleryConfigs := map[string]*gallery.ModelConfig{}
@@ -40,10 +42,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
// Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := opcache.GetStatus()
summary := fiber.Map{
summary := map[string]interface{}{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
@@ -54,12 +56,21 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
"InstalledBackends": installedBackends,
}
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
// Default to HTML if Accept header is empty (browser behavior)
// Only return JSON if explicitly requested or Content-Type is application/json
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
return c.JSON(200, summary)
} else {
// Render index
return c.Render("views/index", summary)
// Check if this is the manage route
templateName := "views/index"
if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" {
templateName = "views/manage"
}
// Render appropriate template
return c.Render(200, templateName, summary)
}
}
}

View File

@@ -1,15 +1,12 @@
package openai
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -20,68 +17,14 @@ import (
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
// Fasthttp doesn't support context cancellation from the caller
// for non-streaming requests, so we need to monitor the connection directly.
// Monitor connection for client disconnection during non-streaming requests
// We access the connection directly via c.Context().Conn() to monitor it
// during ComputeChoices execution, not after the response is sent
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
var conn net.Conn = c.Context().Conn()
if conn == nil {
return
}
go func() {
defer func() {
// Clear read deadline when goroutine exits
conn.SetReadDeadline(time.Time{})
}()
buf := make([]byte, 1)
// Use a short read deadline to periodically check if connection is closed
// Without a deadline, Read() would block indefinitely waiting for data
// that will never come (client is waiting for response, not sending more data)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-requestCtx.Done():
// Request completed or was cancelled - exit goroutine
return
case <-ticker.C:
// Set a short deadline - if connection is closed, read will fail immediately
// If connection is open but no data, it will timeout and we check again
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
_, err := conn.Read(buf)
if err != nil {
// Check if it's a timeout (connection still open, just no data)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout is expected - connection is still open, just no data to read
// Continue the loop to check again
continue
}
// Connection closed or other error - cancel the context to stop gRPC call
log.Debug().Msgf("Calling cancellation function")
cancelFunc()
return
}
}
}
}()
}
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
// @Summary Generate a chat completions for a given prompt and model.
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
var id, textContentToReturn string
var created int
@@ -235,21 +178,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return err
}
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
extraUsage := c.Get("Extra-Usage", "") != ""
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
@@ -392,13 +335,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
case toStream:
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
responses := make(chan schema.OpenAIResponse)
ended := make(chan error, 1)
@@ -411,103 +351,101 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
}()
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
usage := &schema.OpenAIUsage{}
toolsCalled := false
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
input.Cancel()
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(w, "data: %s\n\n", string(respData))
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
}
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &stopReason,
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
w.WriteString("data: {\"error\":\"Internal error\"}\n\n")
} else {
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
}
w.WriteString("data: [DONE]\n\n")
w.Flush()
return
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
input.Cancel()
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &stopReason,
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
}
finishReason := FinishReasonStop
if toolsCalled && len(input.Tools) > 0 {
finishReason = FinishReasonToolCalls
} else if toolsCalled {
finishReason = FinishReasonFunctionCall
}
finishReason := FinishReasonStop
if toolsCalled && len(input.Tools) > 0 {
finishReason = FinishReasonToolCalls
} else if toolsCalled {
finishReason = FinishReasonFunctionCall
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &finishReason,
Index: 0,
Delta: &schema.Message{},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
log.Debug().Msgf("Stream ended")
}))
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &finishReason,
Index: 0,
Delta: &schema.Message{},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
log.Debug().Msgf("Stream ended")
return nil
// no streaming mode
@@ -589,9 +527,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
// NOTE: this is a workaround as fasthttp
// context cancellation does not fire in non-streaming requests
handleConnectionCancellation(c, input.Cancel, input.Context)
// Echo properly supports context cancellation via c.Request().Context()
// No workaround needed!
result, tokenUsage, err := ComputeChoices(
input,
@@ -628,7 +565,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}
}
@@ -698,7 +635,32 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in
}
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON)
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if input.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if input.TopLogprobs != nil {
topLogprobs = input.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = input.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}
// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(input.LogitBias) > 0 {
logitBias = input.LogitBias
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

View File

@@ -1,24 +1,22 @@
package openai
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
@@ -26,7 +24,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
created := int(time.Now().Unix())
@@ -64,22 +62,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
return err
}
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
created := int(time.Now().Unix())
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
extraUsage := c.Get("Extra-Usage", "") != ""
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = uuid.New().String()
}
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
if config.ResponseFormatMap != nil {
@@ -97,15 +98,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
@@ -130,78 +126,78 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
}()
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
continue
}
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
fmt.Fprintf(w, "data: %s\n\n", string(respData))
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
errorResp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model,
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
Text: "Internal error: " + err.Error(),
},
},
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(w, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(w, "data: %s\n\n", string(errorData))
}
w.Flush()
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
break LOOP
}
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
stopReason := FinishReasonStop
errorResp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model,
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
Text: "Internal error: " + err.Error(),
},
},
},
Object: "text_completion",
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
c.Response().Flush()
return nil
}
respData, _ := json.Marshal(resp)
}
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return <-ended
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
var result []schema.Choice
@@ -257,6 +253,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -4,11 +4,11 @@ import (
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
@@ -23,20 +23,20 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/edits [post]
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -12,7 +13,6 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@@ -21,16 +21,16 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Msgf("Parameter Config: %+v", config)
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
@@ -14,13 +15,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -65,18 +66,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
// Process input images (for img2img/inpainting)
@@ -188,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err
}
baseURL := c.BaseURL()
baseURL := middleware.BaseURL(c)
// Use the first input image as src if available, otherwise use the original src
inputSrc := src
@@ -215,7 +216,10 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-images/" + base
item.URL, err = url.JoinPath(baseURL, "generated-images", base)
if err != nil {
return err
}
}
result = append(result, *item)
@@ -234,7 +238,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -55,9 +55,34 @@ func ComputeChoices(
}
}
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if req.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if req.TopLogprobs != nil {
topLogprobs = req.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = req.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}
// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(req.LogitBias) > 0 {
logitBias = req.LogitBias
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON)
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
if err != nil {
return result, backend.TokenUsage{}, err
}
@@ -78,6 +103,11 @@ func ComputeChoices(
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
// Add logprobs to the last choice if present
if prediction.Logprobs != nil && len(result) > 0 {
result[len(result)-1].Logprobs = prediction.Logprobs
}
//result = append(result, Choice{Text: prediction})
}

View File

@@ -1,7 +1,7 @@
package openai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
@@ -12,14 +12,15 @@ import (
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// If blank, no filter is applied.
filter := c.Query("filter")
filter := c.QueryParam("filter")
// By default, exclude any loose files that are already referenced by a configuration file.
var policy services.LooseFilePolicy
if c.QueryBool("excludeConfigured", true) {
excludeConfigured := c.QueryParam("excludeConfigured")
if excludeConfigured == "" || excludeConfigured == "true" {
policy = services.SKIP_IF_CONFIGURED
} else {
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
@@ -41,7 +42,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
return c.JSON(schema.ModelsDataResponse{
return c.JSON(200, schema.ModelsDataResponse{
Object: "list",
Data: dataModels,
})

View File

@@ -8,11 +8,11 @@ import (
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
@@ -26,24 +26,27 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /mcp/v1/completions [post]
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
// We do not support streaming mode (Yet?)
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
created := int(time.Now().Unix())
ctx := c.Context()
ctx := c.Request().Context()
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = uuid.New().String()
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
@@ -80,47 +83,34 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
handleConnectionCancellation(c, cancel, ctxWithCancellation)
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
// we satisfy to just call internally ComputeChoices
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
cogitoOpts := []cogito.Option{
// Build cogito options using the consolidated method
cogitoOpts := config.BuildCogitoOptions()
cogitoOpts = append(
cogitoOpts,
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts
cogito.WithForceReasoning(),
}
if config.Agent.EnableReasoning {
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
}
if config.Agent.EnablePlanning {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
}
if config.Agent.EnableMCPPrompts {
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
}
if config.Agent.EnablePlanReEvaluator {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
}
if config.Agent.MaxIterations != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithIterations(config.Agent.MaxIterations))
}
if config.Agent.MaxAttempts != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(config.Agent.MaxAttempts))
}
cogito.WithReasoningCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
}),
)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
@@ -147,6 +137,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -10,9 +10,11 @@ import (
"sync"
"time"
"net/http"
"github.com/go-audio/audio"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
@@ -167,32 +169,50 @@ type Model interface {
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins
},
}
// TODO: Implement ephemeral keys to allow these endpoints to be used
func RealtimeSessions(application *application.Application) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.SendStatus(501)
func RealtimeSessions(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
return c.NoContent(501)
}
}
func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.SendStatus(501)
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
return c.NoContent(501)
}
}
func Realtime(application *application.Application) fiber.Handler {
return websocket.New(registerRealtime(application))
func Realtime(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer ws.Close()
// Extract query parameters from Echo context before passing to websocket handler
model := c.QueryParam("model")
if model == "" {
model = "gpt-4o"
}
intent := c.QueryParam("intent")
registerRealtime(application, model, intent)(ws)
return nil
}
}
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
evaluator := application.TemplatesEvaluator()
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
model := c.Query("model", "gpt-4o")
intent := c.Query("intent")
if intent != "transcription" {
sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
}
@@ -1067,7 +1087,7 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
// For example, the model might return a special token or JSON indicating a function call
/*
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil)
result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !shouldUseFn {

View File

@@ -7,13 +7,13 @@ import (
"path"
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@@ -24,19 +24,19 @@ import (
// @Param file formData file true "file"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
diarize := c.FormValue("diarize", "false") != "false"
diarize := c.FormValue("diarize") != "false"
// retrieve the file data from the request
file, err := c.FormFile("file")
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
return c.JSON(http.StatusOK, tr)
}
}

View File

@@ -6,7 +6,7 @@ import (
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,20 +14,24 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
var raw map[string]interface{}
if body := c.Body(); len(body) > 0 {
body := make([]byte, 0)
if c.Request().Body != nil {
c.Request().Body.Read(body)
}
if len(body) > 0 {
_ = json.Unmarshal(body, &raw)
}
// Build VideoRequest using shared mapper
vr := MapOpenAIToVideo(input, raw)
// Place VideoRequest into locals so localai.VideoEndpoint can consume it
c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Place VideoRequest into context so localai.VideoEndpoint can consume it
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Delegate to existing localai handler
return localai.VideoEndpoint(cl, ml, appConfig)(c)
}

View File

@@ -1,48 +1,50 @@
package http
import (
"io/fs"
"net/http"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/rs/zerolog/log"
)
func Explorer(db *explorer.Database) *fiber.App {
func Explorer(db *explorer.Database) *echo.Echo {
e := echo.New()
fiberCfg := fiber.Config{
Views: renderEngine(),
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: false,
// Override default error handler
// Set renderer
e.Renderer = renderEngine()
// Hide banner
e.HideBanner = true
e.Pre(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(e, db)
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
// Log error but continue - static files might not work
log.Error().Err(err).Msg("failed to create static filesystem")
} else {
e.StaticFS("/static", staticFS)
}
app := fiber.New(fiberCfg)
app.Use(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(app, db)
httpFS := http.FS(embedDirStatic)
app.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
app.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Define a custom 404 handler
// Note: keep this at the bottom!
app.Use(notFoundHandler)
e.GET("/*", notFoundHandler)
return app
return e
}

View File

@@ -3,50 +3,108 @@ package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
)
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
if err != nil {
return nil, err
}
// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
// Create validator function
validator := getApiKeyValidationFunction(applicationConfig)
return &v2keyauth.Config{
CustomKeyLookup: customLookup,
Next: getApiKeyRequiredFilterFunction(applicationConfig),
Validator: getApiKeyValidationFunction(applicationConfig),
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
AuthScheme: "Bearer",
// Create error handler
errorHandler := getApiKeyErrorHandler(applicationConfig)
// Create Next function (skip middleware for certain requests)
skipper := getApiKeyRequiredFilterFunction(applicationConfig)
// Wrap it with our custom key lookup that checks multiple sources
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if len(applicationConfig.ApiKeys) == 0 {
return next(c)
}
// Skip if skipper says so
if skipper != nil && skipper(c) {
return next(c)
}
// Try to extract key from multiple sources
key, err := extractKeyFromMultipleSources(c)
if err != nil {
return errorHandler(err, c)
}
// Validate the key
valid, err := validator(key, c)
if err != nil || !valid {
return errorHandler(ErrMissingOrMalformedAPIKey, c)
}
// Store key in context for later use
c.Set("api_key", key)
return next(c)
}
}, nil
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
return func(ctx *fiber.Ctx, err error) error {
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
// extractKeyFromMultipleSources checks multiple sources for the API key
// in order: Authorization header, x-api-key header, xi-api-key header, token cookie
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
// Check Authorization header first
auth := c.Request().Header.Get("Authorization")
if auth != "" {
// Check for Bearer scheme
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer "), nil
}
// If no Bearer prefix, return as-is (for backward compatibility)
return auth, nil
}
// Check x-api-key header
if key := c.Request().Header.Get("x-api-key"); key != "" {
return key, nil
}
// Check xi-api-key header
if key := c.Request().Header.Get("xi-api-key"); key != "" {
return key, nil
}
// Check token cookie
cookie, err := c.Cookie("token")
if err == nil && cookie != nil && cookie.Value != "" {
return cookie.Value, nil
}
return "", ErrMissingOrMalformedAPIKey
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
return func(err error, c echo.Context) error {
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
if len(applicationConfig.ApiKeys) == 0 {
return ctx.Next() // if no keys are set up, any error we get here is not an error.
return nil // if no keys are set up, any error we get here is not an error.
}
ctx.Set("WWW-Authenticate", "Bearer")
c.Response().Header().Set("WWW-Authenticate", "Bearer")
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(401)
return c.NoContent(http.StatusUnauthorized)
}
// Check if the request content type is JSON
contentType := string(ctx.Context().Request.Header.ContentType())
contentType := c.Request().Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
return ctx.Status(401).JSON(schema.ErrorResponse{
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
Error: &schema.APIError{
Message: "An authentication key is required",
Code: 401,
@@ -55,50 +113,69 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
})
}
return ctx.Status(401).Render("views/login", fiber.Map{
"BaseURL": utils.BaseURL(ctx),
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
"BaseURL": BaseURL(c),
})
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
return c.NoContent(http.StatusInternalServerError)
}
return err
}
}
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
if applicationConfig.UseSubtleKeyComparison {
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
return func(key string, c echo.Context) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
return false, ErrMissingOrMalformedAPIKey
}
}
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
return func(key string, c echo.Context) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if apiKey == validKey {
if key == validKey {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
return false, ErrMissingOrMalformedAPIKey
}
}
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
if applicationConfig.DisableApiKeyRequirementForHttpGet {
return func(c *fiber.Ctx) bool {
if c.Method() != "GET" {
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
return func(c echo.Context) bool {
path := c.Request().URL.Path
// Always skip authentication for static files
if strings.HasPrefix(path, "/static/") {
return true
}
// Always skip authentication for generated content
if strings.HasPrefix(path, "/generated-audio/") ||
strings.HasPrefix(path, "/generated-images/") ||
strings.HasPrefix(path, "/generated-videos/") {
return true
}
// Skip authentication for favicon
if path == "/favicon.svg" {
return true
}
// Handle GET request exemptions if enabled
if applicationConfig.DisableApiKeyRequirementForHttpGet {
if c.Request().Method != http.MethodGet {
return false
}
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
@@ -106,8 +183,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
return true
}
}
return false
}
return false
}
return func(c *fiber.Ctx) bool { return false }
}

View File

@@ -0,0 +1,48 @@
package middleware
import (
"strings"
"github.com/labstack/echo/v4"
)
// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
func BaseURL(c echo.Context) string {
path := c.Path()
origPath := c.Request().URL.Path
// Check if StripPathPrefix middleware stored the original path
if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
origPath = storedPath
}
// Check X-Forwarded-Proto for scheme
scheme := "http"
if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
} else if c.Request().TLS != nil {
scheme = "https"
}
// Check X-Forwarded-Host for host
host := c.Request().Host
if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
prefixLen := len(origPath) - len(path)
if prefixLen > 0 && prefixLen <= len(origPath) {
pathPrefix := origPath[:prefixLen]
if !strings.HasSuffix(pathPrefix, "/") {
pathPrefix += "/"
}
return scheme + "://" + host + pathPrefix
}
}
return scheme + "://" + host + "/"
}

View File

@@ -0,0 +1,58 @@
package middleware
import (
"net/http/httptest"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("BaseURL", func() {
Context("without prefix", func() {
It("should return base URL without prefix", func() {
app := echo.New()
actualURL := ""
// Register route - use the actual request path so routing works
routePath := "/hello/world"
app.GET(routePath, func(c echo.Context) error {
actualURL = BaseURL(c)
return nil
})
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualURL).To(Equal("http://example.com/"), "base URL")
})
})
Context("with prefix", func() {
It("should return base URL with prefix", func() {
app := echo.New()
actualURL := ""
// Register route with the stripped path (after middleware removes prefix)
routePath := "/hello/world"
app.GET(routePath, func(c echo.Context) error {
// Simulate what StripPathPrefix middleware does - store original path
c.Set("_original_path", "/myprefix/hello/world")
// Modify the request path to simulate prefix stripping
c.Request().URL.Path = "/hello/world"
actualURL = BaseURL(c)
return nil
})
// Make request with stripped path (middleware would have already processed it)
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
})
})
})

View File

@@ -0,0 +1,13 @@
package middleware_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestMiddleware(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware test suite")
}

View File

@@ -1,470 +1,482 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
modelConfigLoader: modelConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" {
return
}
model = ctx.Params("model")
if (model == "") && ctx.Query("model") != "" {
model = ctx.Query("model")
}
if model == "" {
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
if bearer != "" {
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
}
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" {
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
}
return ctx.Next()
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values
return ctx.Next()
}
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return ctx.Next()
}
if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed")
return ctx.Next()
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return ctx.Next()
}
}
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
return func(ctx *fiber.Ctx) error {
input := initializer()
if input == nil {
return fmt.Errorf("unable to initialize body")
}
if err := ctx.BodyParser(input); err != nil {
return fmt.Errorf("failed parsing request body: %w", err)
}
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" {
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
}
}
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
cfg.Model = input.ModelName(nil)
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
}
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
// Extract or generate the correlation ID
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
ctx.Set("X-Correlation-ID", correlationID)
//c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Use the application context as parent to ensure cancellation on app shutdown
// We'll monitor the Fiber context separately and cancel our context when the request is canceled
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Monitor the Fiber context and cancel our context when it's canceled
// This ensures we respect request cancellation without causing panics
go func(fiberCtx *fasthttp.RequestCtx) {
if fiberCtx != nil {
<-fiberCtx.Done()
cancel()
}
}(ctx.Context())
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
err := mergeOpenAIRequestAndModelConfig(cfg, input)
if err != nil {
return err
}
if cfg.Model == "" {
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
cfg.Model = input.Model
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding audio: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "input_audio":
// TODO: make sure that we only return base64 stuff
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []any:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []any:
tokens := []int{}
inputStrings := []string{}
for _, ii := range i {
switch ii := ii.(type) {
case int:
tokens = append(tokens, ii)
case float64:
tokens = append(tokens, int(ii))
case string:
inputStrings = append(inputStrings, ii)
default:
log.Error().Msgf("Unknown input type: %T", ii)
}
}
config.InputToken = append(config.InputToken, tokens)
config.InputStrings = append(config.InputStrings, inputStrings...)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
if config.Validate() {
return nil
}
return fmt.Errorf("unable to validate configuration after merging")
}
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
modelConfigLoader: modelConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" {
return
}
model = c.Param("model")
if model == "" {
model = c.QueryParam("model")
}
if model == "" {
// Set model from bearer token, if available
auth := c.Request().Header.Get("Authorization")
bearer := strings.TrimPrefix(auth, "Bearer ")
if bearer != "" && bearer != auth {
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
}
}
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
re.setModelNameFromRequest(c)
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" {
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
}
return next(c)
}
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
re.setModelNameFromRequest(c)
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values
return next(c)
}
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return next(c)
}
if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed")
return next(c)
}
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return next(c)
}
}
}
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
input := initializer()
if input == nil {
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
}
if err := c.Bind(input); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
}
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" {
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
}
}
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
cfg.Model = input.ModelName(nil)
}
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return next(c)
}
}
}
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
}
// Extract or generate the correlation ID
correlationID := c.Request().Header.Get("X-Correlation-ID")
if correlationID == "" {
correlationID = uuid.New().String()
}
c.Response().Header().Set("X-Correlation-ID", correlationID)
// Use the request context directly - Echo properly supports context cancellation!
// No need for workarounds like handleConnectionCancellation
reqCtx := c.Request().Context()
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Cancel when request context is cancelled (client disconnects)
go func() {
select {
case <-reqCtx.Done():
cancel()
case <-c1.Done():
// Already cancelled
}
}()
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
err := mergeOpenAIRequestAndModelConfig(cfg, input)
if err != nil {
return err
}
if cfg.Model == "" {
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
cfg.Model = input.Model
}
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return nil
}
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding audio: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "input_audio":
// TODO: make sure that we only return base64 stuff
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []any:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []any:
tokens := []int{}
inputStrings := []string{}
for _, ii := range i {
switch ii := ii.(type) {
case int:
tokens = append(tokens, ii)
case float64:
tokens = append(tokens, int(ii))
case string:
inputStrings = append(inputStrings, ii)
default:
log.Error().Msgf("Unknown input type: %T", ii)
}
}
config.InputToken = append(config.InputToken, tokens)
config.InputStrings = append(config.InputStrings, inputStrings...)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
if config.Validate() {
return nil
}
return fmt.Errorf("unable to validate configuration after merging")
}

View File

@@ -3,34 +3,55 @@ package middleware
import (
"strings"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
)
// StripPathPrefix returns a middleware that strips a path prefix from the request path.
// StripPathPrefix returns middleware that strips a path prefix from the request path.
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
func StripPathPrefix() fiber.Handler {
return func(c *fiber.Ctx) error {
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
if prefix != "" {
path := c.Path()
pos := len(prefix)
// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
func StripPathPrefix() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
originalPath := c.Request().URL.Path
if prefix[pos-1] == '/' {
pos--
} else {
prefix += "/"
}
for _, prefix := range prefixes {
if prefix != "" {
normalizedPrefix := prefix
if !strings.HasSuffix(prefix, "/") {
normalizedPrefix = prefix + "/"
}
if strings.HasPrefix(path, prefix) {
c.Path(path[pos:])
break
} else if prefix[:pos] == path {
c.Redirect(prefix)
return nil
if strings.HasPrefix(originalPath, normalizedPrefix) {
// Update the request path by stripping the normalized prefix
newPath := originalPath[len(normalizedPrefix):]
if newPath == "" {
newPath = "/"
}
// Ensure path starts with / for proper routing
if !strings.HasPrefix(newPath, "/") {
newPath = "/" + newPath
}
// Update the URL path - Echo's router uses URL.Path for routing
c.Request().URL.Path = newPath
c.Request().URL.RawPath = ""
// Update RequestURI to match the new path (needed for proper routing)
if c.Request().URL.RawQuery != "" {
c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
} else {
c.Request().RequestURI = newPath
}
// Store original path for BaseURL utility
c.Set("_original_path", originalPath)
break
} else if originalPath == prefix || originalPath == prefix+"/" {
// Redirect to prefix with trailing slash (use 302 to match test expectations)
return c.Redirect(302, normalizedPrefix)
}
}
}
}
return c.Next()
return next(c)
}
}
}

View File

@@ -2,120 +2,133 @@ package middleware
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestStripPathPrefix(t *testing.T) {
var _ = Describe("StripPathPrefix", func() {
var app *echo.Echo
var actualPath string
var appInitialized bool
app := fiber.New()
BeforeEach(func() {
actualPath = ""
if !appInitialized {
app = echo.New()
app.Pre(StripPathPrefix())
app.Use(StripPathPrefix())
app.GET("/hello/world", func(c echo.Context) error {
actualPath = c.Request().URL.Path
return nil
})
app.Get("/hello/world", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
app.GET("/", func(c echo.Context) error {
actualPath = c.Request().URL.Path
return nil
})
appInitialized = true
}
})
app.Get("/", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})
Context("without prefix", func() {
It("should not modify path when no header is present", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
for _, tc := range []struct {
name string
path string
prefixHeader []string
expectStatus int
expectPath string
}{
{
name: "without prefix and header",
path: "/hello/world",
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "without prefix and headers on root path",
path: "/",
expectStatus: 200,
expectPath: "/",
},
{
name: "without prefix but header",
path: "/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix but non-matching header",
path: "/prefix/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 404,
},
{
name: "with prefix and matching header",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 1st header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 2nd header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and header not ending with slash",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and non-matching header not ending with slash",
path: "/myprefix-suffix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 404,
},
{
name: "redirect when prefix does not end with a slash",
path: "/myprefix",
prefixHeader: []string{"/myprefix"},
expectStatus: 302,
expectPath: "/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
actualPath = ""
req := httptest.NewRequest("GET", tc.path, nil)
if tc.prefixHeader != nil {
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
}
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
if tc.expectStatus == 200 {
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
} else if tc.expectStatus == 302 {
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
}
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
}
}
It("should not modify root path when no header is present", func() {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/"), "rewritten path")
})
It("should not modify path when header does not match", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
})
Context("with prefix", func() {
It("should return 404 when prefix does not match header", func() {
req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should strip matching prefix from path", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the first header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the second header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when header does not end with slash", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should return 404 when prefix does not match header without trailing slash", func() {
req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should redirect when prefix does not end with a slash", func() {
req := httptest.NewRequest("GET", "/myprefix", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(302), "response status code")
Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
})
})
})

View File

@@ -17,7 +17,7 @@ import (
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"fmt"
. "github.com/mudler/LocalAI/core/http"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
var tmpdir string
var appServer *application.Application
var app *fiber.App
var app *echo.Echo
var ctx context.Context
var cancel context.CancelFunc
@@ -97,7 +97,9 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
AfterEach(func() {
cancel()
if app != nil {
_ = app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = app.Shutdown(ctx)
}
_ = os.RemoveAll(tmpdir)
})
@@ -106,7 +108,11 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
var err error
app, err = API(appServer)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9091")
go func() {
if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
// Log error if needed
}
}()
// wait for server
client := &http.Client{Timeout: 5 * time.Second}

View File

@@ -4,13 +4,15 @@ import (
"embed"
"fmt"
"html/template"
"io"
"io/fs"
"net/http"
"strings"
"github.com/Masterminds/sprig/v3"
"github.com/gofiber/fiber/v2"
fiberhtml "github.com/gofiber/template/html/v2"
"github.com/labstack/echo/v4"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/russross/blackfriday"
)
@@ -18,26 +20,67 @@ import (
//go:embed views/*
var viewsfs embed.FS
func notFoundHandler(c *fiber.Ctx) error {
// TemplateRenderer is a custom template renderer for Echo
type TemplateRenderer struct {
templates *template.Template
}
// Render renders a template document
func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
return t.templates.ExecuteTemplate(w, name, data)
}
func notFoundHandler(c echo.Context) error {
// Check if the request accepts JSON
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") {
// The client expects a JSON response
return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound},
})
} else {
// The client expects an HTML response
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{
"BaseURL": utils.BaseURL(c),
return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{
"BaseURL": middleware.BaseURL(c),
})
}
}
func renderEngine() *fiberhtml.Engine {
engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
engine.AddFuncMap(sprig.FuncMap())
engine.AddFunc("MDToHTML", markDowner)
return engine
func renderEngine() *TemplateRenderer {
// Parse all templates from embedded filesystem
tmpl := template.New("").Funcs(sprig.FuncMap())
tmpl = tmpl.Funcs(template.FuncMap{
"MDToHTML": markDowner,
})
// Recursively walk through embedded filesystem and parse all HTML templates
err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && strings.HasSuffix(path, ".html") {
data, err := viewsfs.ReadFile(path)
if err == nil {
// Remove .html extension to get template name (e.g., "views/index.html" -> "views/index")
templateName := strings.TrimSuffix(path, ".html")
_, err := tmpl.New(templateName).Parse(string(data))
if err != nil {
// If parsing fails, try parsing without explicit name (for templates with {{define}})
tmpl.Parse(string(data))
}
}
}
return nil
})
if err != nil {
// Log error but continue - templates might still work
fmt.Printf("Error walking views directory: %v\n", err)
}
return &TemplateRenderer{
templates: tmpl,
}
}
func markDowner(args ...interface{}) template.HTML {

View File

@@ -1,7 +1,7 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -9,21 +9,23 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterElevenLabsRoutes(app *fiber.App,
func RegisterElevenLabsRoutes(app *echo.Echo,
re *middleware.RequestExtractor,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id",
ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig)
app.POST("/v1/text-to-speech/:voice-id",
ttsHandler,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
elevenlabs.TTSEndpoint(cl, ml, appConfig))
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }))
app.Post("/v1/sound-generation",
soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)
app.POST("/v1/sound-generation",
soundGenHandler,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }))
}

View File

@@ -1,13 +1,13 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
coreExplorer "github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/endpoints/explorer"
)
func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database) {
app.Get("/", explorer.Dashboard())
app.Post("/network/add", explorer.AddNetwork(db))
app.Get("/networks", explorer.ShowNetworks(db))
func RegisterExplorerRoutes(app *echo.Echo, db *coreExplorer.Database) {
app.GET("/", explorer.Dashboard())
app.POST("/network/add", explorer.AddNetwork(db))
app.GET("/networks", explorer.ShowNetworks(db))
}

View File

@@ -1,13 +1,15 @@
package routes
import "github.com/gofiber/fiber/v2"
import (
"github.com/labstack/echo/v4"
)
func HealthRoutes(app *fiber.App) {
func HealthRoutes(app *echo.Echo) {
// Service health checks
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
ok := func(c echo.Context) error {
return c.NoContent(200)
}
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.GET("/healthz", ok)
app.GET("/readyz", ok)
}

View File

@@ -1,24 +1,25 @@
package routes
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/jina"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterJINARoutes(app *fiber.App,
func RegisterJINARoutes(app *echo.Echo,
re *middleware.RequestExtractor,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank",
rerankHandler := jina.JINARerankEndpoint(cl, ml, appConfig)
app.POST("/v1/rerank",
rerankHandler,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
jina.JINARerankEndpoint(cl, ml, appConfig))
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }))
}

View File

@@ -1,131 +1,157 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/swagger"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
echoswagger "github.com/swaggo/echo-swagger"
)
func RegisterLocalAIRoutes(router *fiber.App,
func RegisterLocalAIRoutes(router *echo.Echo,
requestExtractor *middleware.RequestExtractor,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService,
opcache *services.OpCache) {
opcache *services.OpCache,
evaluator *templates.Evaluator) {
router.Get("/swagger/*", swagger.HandlerDefault) // default
router.GET("/swagger/*", echoswagger.WrapHandler) // default
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
// Import model page
router.Get("/import-model", func(c *fiber.Ctx) error {
return c.Render("views/model-editor", fiber.Map{
router.GET("/import-model", func(c echo.Context) error {
return c.Render(200, "views/model-editor", map[string]interface{}{
"Title": "LocalAI - Import Model",
"BaseURL": httpUtils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
})
})
// Edit model page
router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
router.GET("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
router.POST("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.POST("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
router.GET("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
router.GET("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
router.GET("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
router.GET("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
backendGalleryEndpointService := localai.CreateBackendEndpointService(
appConfig.BackendGalleries,
appConfig.SystemState,
galleryService)
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
router.POST("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
router.POST("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
router.GET("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
router.GET("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
router.GET("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
router.GET("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
// Custom model import endpoint
router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig))
router.POST("/models/import", localai.ImportModelEndpoint(cl, appConfig))
// URI model import endpoint
router.Post("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
router.POST("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
// Custom model edit endpoint
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
router.POST("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
// Reload models endpoint
router.Post("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
}
router.Post("/v1/detection",
detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig)
router.POST("/v1/detection",
detectionHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
localai.DetectionEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }))
router.Post("/tts",
ttsHandler := localai.TTSEndpoint(cl, ml, appConfig)
router.POST("/tts",
ttsHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }))
vadChain := []fiber.Handler{
vadHandler := localai.VADEndpoint(cl, ml, appConfig)
router.POST("/vad",
vadHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
localai.VADEndpoint(cl, ml, appConfig),
}
router.Post("/vad", vadChain...)
router.Post("/v1/vad", vadChain...)
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
router.POST("/v1/vad",
vadHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
// Stores
router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
router.POST("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
router.POST("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
router.POST("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
router.POST("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
if !appConfig.DisableMetrics {
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
router.GET("/metrics", localai.LocalAIMetricsEndpoint())
}
router.Post("/video",
videoHandler := localai.VideoEndpoint(cl, ml, appConfig)
router.POST("/video",
videoHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }),
localai.VideoEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }))
// Backend Statistics Module
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
router.GET("/api/p2p", localai.ShowP2PNodes(appConfig))
router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig))
router.Get("/version", func(c *fiber.Ctx) error {
return c.JSON(struct {
router.GET("/version", func(c echo.Context) error {
return c.JSON(200, struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
router.Get("/system", localai.SystemInformations(ml, appConfig))
router.GET("/system", localai.SystemInformations(ml, appConfig))
// misc
router.Post("/v1/tokenize",
tokenizeHandler := localai.TokenizeEndpoint(cl, ml, appConfig)
router.POST("/v1/tokenize",
tokenizeHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
localai.TokenizeEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }))
// MCP Stream endpoint
if evaluator != nil {
mcpStreamHandler := localai.MCPStreamEndpoint(cl, ml, evaluator, appConfig)
mcpStreamMiddleware := []echo.MiddlewareFunc{
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := requestExtractor.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
router.POST("/v1/mcp/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
router.POST("/mcp/v1/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
}
}

View File

@@ -1,7 +1,7 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
@@ -10,118 +10,172 @@ import (
"github.com/mudler/LocalAI/core/schema"
)
func RegisterOpenAIRoutes(app *fiber.App,
func RegisterOpenAIRoutes(app *echo.Echo,
re *middleware.RequestExtractor,
application *application.Application) {
// openAI compatible API endpoint
// realtime
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
app.Get("/v1/realtime", openai.Realtime(application))
app.Post("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
app.Post("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
app.GET("/v1/realtime", openai.Realtime(application))
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
// chat
chatChain := []fiber.Handler{
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
chatMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/chat/completions", chatChain...)
app.Post("/chat/completions", chatChain...)
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
app.POST("/chat/completions", chatHandler, chatMiddleware...)
// edit
editChain := []fiber.Handler{
editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
editMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/edits", editChain...)
app.Post("/edits", editChain...)
app.POST("/v1/edits", editHandler, editMiddleware...)
app.POST("/edits", editHandler, editMiddleware...)
// completion
completionChain := []fiber.Handler{
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
completionMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/completions", completionChain...)
app.Post("/completions", completionChain...)
app.Post("/v1/engines/:model/completions", completionChain...)
app.POST("/v1/completions", completionHandler, completionMiddleware...)
app.POST("/completions", completionHandler, completionMiddleware...)
app.POST("/v1/engines/:model/completions", completionHandler, completionMiddleware...)
// MCPcompletion
mcpCompletionChain := []fiber.Handler{
mcpCompletionHandler := openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
mcpCompletionMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/mcp/v1/chat/completions", mcpCompletionChain...)
app.Post("/mcp/chat/completions", mcpCompletionChain...)
app.POST("/mcp/v1/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
app.POST("/mcp/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
// embeddings
embeddingChain := []fiber.Handler{
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
embeddingMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/embeddings", embeddingChain...)
app.Post("/embeddings", embeddingChain...)
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...)
app.POST("/embeddings", embeddingHandler, embeddingMiddleware...)
app.POST("/v1/engines/:model/embeddings", embeddingHandler, embeddingMiddleware...)
audioChain := []fiber.Handler{
audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
audioMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
// audio
app.Post("/v1/audio/transcriptions", audioChain...)
app.Post("/audio/transcriptions", audioChain...)
app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...)
app.POST("/audio/transcriptions", audioHandler, audioMiddleware...)
audioSpeechChain := []fiber.Handler{
audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
audioSpeechMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
}
app.Post("/v1/audio/speech",
audioSpeechChain...)
app.Post("/audio/speech", audioSpeechChain...)
app.POST("/v1/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
app.POST("/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
// images
imageChain := []fiber.Handler{
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
imageMiddleware := []echo.MiddlewareFunc{
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/images/generations",
imageChain...)
app.Post("/images/generations", imageChain...)
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
app.POST("/images/generations", imageHandler, imageMiddleware...)
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
videoChain := []fiber.Handler{
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
videoMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
// OpenAI-style create video endpoint
app.Post("/v1/videos", videoChain...)
app.Post("/v1/videos/generations", videoChain...)
app.Post("/videos", videoChain...)
app.POST("/v1/videos", videoHandler, videoMiddleware...)
app.POST("/v1/videos/generations", videoHandler, videoMiddleware...)
app.POST("/videos", videoHandler, videoMiddleware...)
// List models
app.Get("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
}

View File

@@ -1,18 +1,17 @@
package routes
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterUIRoutes(app *fiber.App,
func RegisterUIRoutes(app *echo.Echo,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
@@ -21,13 +20,14 @@ func RegisterUIRoutes(app *fiber.App,
// keeps the state of ops that are started from the UI
var processingOps = services.NewOpCache(galleryService)
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
app.GET("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
app.GET("/manage", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
// P2P
app.Get("/p2p", func(c *fiber.Ctx) error {
summary := fiber.Map{
app.GET("/p2p/", func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI - P2P dashboard",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
//"Nodes": p2p.GetAvailableNodes(""),
//"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID),
@@ -37,7 +37,7 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/p2p", summary)
return c.Render(200, "views/p2p", summary)
})
// Note: P2P UI fragment routes (/p2p/ui/*) were removed
@@ -50,17 +50,17 @@ func RegisterUIRoutes(app *fiber.App,
registerBackendGalleryRoutes(app, appConfig, galleryService, processingOps)
}
app.Get("/talk/", func(c *fiber.Ctx) error {
app.GET("/talk/", func(c echo.Context) error {
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
if len(modelConfigs) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": "LocalAI - Talk",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"Model": modelConfigs[0],
@@ -68,16 +68,16 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/talk", summary)
return c.Render(200, "views/talk", summary)
})
app.Get("/chat/", func(c *fiber.Ctx) error {
app.GET("/chat/", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
modelThatCanBeUsed := ""
galleryConfigs := map[string]*gallery.ModelConfig{}
@@ -104,9 +104,9 @@ func RegisterUIRoutes(app *fiber.App,
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": title,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsWithoutConfig": modelsWithoutConfig,
"GalleryConfig": galleryConfigs,
"ModelsConfig": modelConfigs,
@@ -116,16 +116,16 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/chat", summary)
return c.Render(200, "views/chat", summary)
})
// Show the Chat page
app.Get("/chat/:model", func(c *fiber.Ctx) error {
app.GET("/chat/:model", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
galleryConfigs := map[string]*gallery.ModelConfig{}
modelName := c.Params("model")
modelName := c.Param("model")
var modelContextSize *int
for _, m := range modelConfigs {
@@ -139,9 +139,9 @@ func RegisterUIRoutes(app *fiber.App,
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": "LocalAI - Chat with " + modelName,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
@@ -151,33 +151,33 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/chat", summary)
return c.Render(200, "views/chat", summary)
})
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
app.GET("/text2image/:model", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
summary := map[string]interface{}{
"Title": "LocalAI - Generate images with " + c.Param("model"),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
"Model": c.Param("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/text2image", summary)
return c.Render(200, "views/text2image", summary)
})
app.Get("/text2image/", func(c *fiber.Ctx) error {
app.GET("/text2image/", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
modelThatCanBeUsed := ""
@@ -191,9 +191,9 @@ func RegisterUIRoutes(app *fiber.App,
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": title,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": modelThatCanBeUsed,
@@ -201,33 +201,33 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/text2image", summary)
return c.Render(200, "views/text2image", summary)
})
app.Get("/tts/:model", func(c *fiber.Ctx) error {
app.GET("/tts/:model", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
summary := map[string]interface{}{
"Title": "LocalAI - Generate images with " + c.Param("model"),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
"Model": c.Param("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/tts", summary)
return c.Render(200, "views/tts", summary)
})
app.Get("/tts/", func(c *fiber.Ctx) error {
app.GET("/tts/", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
modelThatCanBeUsed := ""
@@ -240,9 +240,9 @@ func RegisterUIRoutes(app *fiber.App,
break
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": title,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": modelThatCanBeUsed,
@@ -250,6 +250,6 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/tts", summary)
return c.Render(200, "views/tts", summary)
})
}

View File

@@ -1,30 +1,33 @@
package routes
import (
"context"
"fmt"
"math"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// RegisterUIAPIRoutes registers JSON API routes for the web UI
func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
// Operations API - Get all current operations (models + backends)
app.Get("/api/operations", func(c *fiber.Ctx) error {
app.GET("/api/operations", func(c echo.Context) error {
processingData, taskTypes := opcache.GetStatus()
operations := []fiber.Map{}
operations := []map[string]interface{}{}
for galleryID, jobID := range processingData {
taskType := "installation"
if tt, ok := taskTypes[galleryID]; ok {
@@ -35,23 +38,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
progress := 0
isDeletion := false
isQueued := false
isCancelled := false
isCancellable := false
message := ""
if status != nil {
// Skip completed operations
if status.Processed {
// Skip completed operations (unless cancelled and not yet cleaned up)
if status.Processed && !status.Cancelled {
continue
}
// Skip cancelled operations that are processed (they're done, no need to show)
if status.Processed && status.Cancelled {
continue
}
progress = int(status.Progress)
isDeletion = status.Deletion
isCancelled = status.Cancelled
isCancellable = status.Cancellable
message = status.Message
if isDeletion {
taskType = "deletion"
}
if isCancelled {
taskType = "cancelled"
}
} else {
// Job is queued but hasn't started
isQueued = true
isCancellable = true
message = "Operation queued"
}
@@ -75,17 +90,19 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
}
operations = append(operations, fiber.Map{
"id": galleryID,
"name": displayName,
"fullName": galleryID,
"jobID": jobID,
"progress": progress,
"taskType": taskType,
"isDeletion": isDeletion,
"isBackend": isBackend,
"isQueued": isQueued,
"message": message,
operations = append(operations, map[string]interface{}{
"id": galleryID,
"name": displayName,
"fullName": galleryID,
"jobID": jobID,
"progress": progress,
"taskType": taskType,
"isDeletion": isDeletion,
"isBackend": isBackend,
"isQueued": isQueued,
"isCancelled": isCancelled,
"cancellable": isCancellable,
"message": message,
})
}
@@ -103,21 +120,49 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
return operations[i]["id"].(string) < operations[j]["id"].(string)
})
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"operations": operations,
})
})
// Cancel operation endpoint
app.POST("/api/operations/:jobID/cancel", func(c echo.Context) error {
jobID := c.Param("jobID")
log.Debug().Msgf("API request to cancel operation: %s", jobID)
err := galleryService.CancelOperation(jobID)
if err != nil {
log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": err.Error(),
})
}
// Clean up opcache for cancelled operation
opcache.DeleteUUID(jobID)
return c.JSON(200, map[string]interface{}{
"success": true,
"message": "Operation cancelled",
})
})
// Model Gallery APIs
app.Get("/api/models", func(c *fiber.Ctx) error {
term := c.Query("term")
page := c.Query("page", "1")
items := c.Query("items", "21")
app.GET("/api/models", func(c echo.Context) error {
term := c.QueryParam("term")
page := c.QueryParam("page")
if page == "" {
page = "1"
}
items := c.QueryParam("items")
if items == "" {
items = "21"
}
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -160,7 +205,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
// Convert models to JSON-friendly format and deduplicate by ID
modelsJSON := make([]fiber.Map, 0, len(models))
modelsJSON := make([]map[string]interface{}, 0, len(models))
seenIDs := make(map[string]bool)
for _, m := range models {
@@ -186,7 +231,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
_, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
modelsJSON = append(modelsJSON, fiber.Map{
modelsJSON = append(modelsJSON, map[string]interface{}{
"id": modelID,
"name": m.Name,
"description": m.Description,
@@ -213,26 +258,32 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
nextPage = totalPages
}
return c.JSON(fiber.Map{
"models": modelsJSON,
"repositories": appConfig.Galleries,
"allTags": tags,
"processingModels": processingModelsData,
"taskTypes": taskTypes,
"availableModels": totalModels,
"currentPage": pageNum,
"totalPages": totalPages,
"prevPage": prevPage,
"nextPage": nextPage,
// Calculate installed models count (models with configs + models without configs)
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
installedModelsCount := len(modelConfigs) + len(modelsWithoutConfig)
return c.JSON(200, map[string]interface{}{
"models": modelsJSON,
"repositories": appConfig.Galleries,
"allTags": tags,
"processingModels": processingModelsData,
"taskTypes": taskTypes,
"availableModels": totalModels,
"installedModels": installedModelsCount,
"currentPage": pageNum,
"totalPages": totalPages,
"prevPage": prevPage,
"nextPage": nextPage,
})
})
app.Post("/api/models/install/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id"))
app.POST("/api/models/install/:id", func(c echo.Context) error {
galleryID := c.Param("id")
// URL decode the gallery ID (e.g., "localai%40model" -> "localai@model")
galleryID, err := url.QueryUnescape(galleryID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid model ID",
})
}
@@ -240,7 +291,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -248,28 +299,33 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
uid := id.String()
opcache.Set(galleryID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uid,
GalleryElementName: galleryID,
Galleries: appConfig.Galleries,
BackendGalleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.ModelGalleryChannel <- op
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Installation started",
})
})
app.Post("/api/models/delete/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id"))
app.POST("/api/models/delete/:id", func(c echo.Context) error {
galleryID := c.Param("id")
// URL decode the gallery ID
galleryID, err := url.QueryUnescape(galleryID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid model ID",
})
}
@@ -282,7 +338,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -291,30 +347,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
opcache.Set(galleryID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uid,
Delete: true,
GalleryElementName: galleryName,
Galleries: appConfig.Galleries,
BackendGalleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.ModelGalleryChannel <- op
cl.RemoveModelConfig(galleryName)
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Deletion started",
})
})
app.Post("/api/models/config/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id"))
app.POST("/api/models/config/:id", func(c echo.Context) error {
galleryID := c.Param("id")
// URL decode the gallery ID
galleryID, err := url.QueryUnescape(galleryID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid model ID",
})
}
@@ -322,44 +383,44 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
model := gallery.FindGalleryElement(models, galleryID)
if model == nil {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
return c.JSON(http.StatusNotFound, map[string]interface{}{
"error": "model not found",
})
}
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](model.URL, appConfig.SystemState.Model.ModelsPath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
_, err = gallery.InstallModel(appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
_, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"message": "Configuration file saved",
})
})
app.Get("/api/models/job/:uid", func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid"))
app.GET("/api/models/job/:uid", func(c echo.Context) error {
jobUID := c.Param("uid")
status := galleryService.GetStatus(jobUID)
if status == nil {
// Job is queued but hasn't started processing yet
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"progress": 0,
"message": "Operation queued",
"galleryElementName": "",
@@ -369,7 +430,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
response := fiber.Map{
response := map[string]interface{}{
"progress": status.Progress,
"message": status.Message,
"galleryElementName": status.GalleryElementName,
@@ -387,19 +448,25 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
response["completed"] = true
}
return c.JSON(response)
return c.JSON(200, response)
})
// Backend Gallery APIs
app.Get("/api/backends", func(c *fiber.Ctx) error {
term := c.Query("term")
page := c.Query("page", "1")
items := c.Query("items", "21")
app.GET("/api/backends", func(c echo.Context) error {
term := c.QueryParam("term")
page := c.QueryParam("page")
if page == "" {
page = "1"
}
items := c.QueryParam("items")
if items == "" {
items = "21"
}
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
if err != nil {
log.Error().Err(err).Msg("could not list backends from galleries")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -442,7 +509,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
// Convert backends to JSON-friendly format and deduplicate by ID
backendsJSON := make([]fiber.Map, 0, len(backends))
backendsJSON := make([]map[string]interface{}, 0, len(backends))
seenBackendIDs := make(map[string]bool)
for _, b := range backends {
@@ -466,7 +533,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
}
backendsJSON = append(backendsJSON, fiber.Map{
backendsJSON = append(backendsJSON, map[string]interface{}{
"id": backendID,
"name": b.Name,
"description": b.Description,
@@ -491,13 +558,21 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
nextPage = totalPages
}
return c.JSON(fiber.Map{
// Calculate installed backends count
installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
installedBackendsCount := 0
if err == nil {
installedBackendsCount = len(installedBackends)
}
return c.JSON(200, map[string]interface{}{
"backends": backendsJSON,
"repositories": appConfig.BackendGalleries,
"allTags": tags,
"processingBackends": processingBackendsData,
"taskTypes": taskTypes,
"availableBackends": totalBackends,
"installedBackends": installedBackendsCount,
"currentPage": pageNum,
"totalPages": totalPages,
"prevPage": prevPage,
@@ -505,12 +580,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
})
app.Post("/api/backends/install/:id", func(c *fiber.Ctx) error {
backendID := strings.Clone(c.Params("id"))
app.POST("/api/backends/install/:id", func(c echo.Context) error {
backendID := c.Param("id")
// URL decode the backend ID
backendID, err := url.QueryUnescape(backendID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid backend ID",
})
}
@@ -518,7 +593,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -526,27 +601,32 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
uid := id.String()
opcache.Set(backendID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryBackend, any]{
ID: uid,
GalleryElementName: backendID,
Galleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.BackendGalleryChannel <- op
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Backend installation started",
})
})
app.Post("/api/backends/delete/:id", func(c *fiber.Ctx) error {
backendID := strings.Clone(c.Params("id"))
app.POST("/api/backends/delete/:id", func(c echo.Context) error {
backendID := c.Param("id")
// URL decode the backend ID
backendID, err := url.QueryUnescape(backendID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid backend ID",
})
}
@@ -559,7 +639,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -568,29 +648,34 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
opcache.Set(backendID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryBackend, any]{
ID: uid,
Delete: true,
GalleryElementName: backendName,
Galleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.BackendGalleryChannel <- op
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Backend deletion started",
})
})
app.Get("/api/backends/job/:uid", func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid"))
app.GET("/api/backends/job/:uid", func(c echo.Context) error {
jobUID := c.Param("uid")
status := galleryService.GetStatus(jobUID)
if status == nil {
// Job is queued but hasn't started processing yet
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"progress": 0,
"message": "Operation queued",
"galleryElementName": "",
@@ -600,7 +685,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
response := fiber.Map{
response := map[string]interface{}{
"progress": status.Progress,
"message": status.Message,
"galleryElementName": status.GalleryElementName,
@@ -618,16 +703,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
response["completed"] = true
}
return c.JSON(response)
return c.JSON(200, response)
})
// System Backend Deletion API (for installed backends on index page)
app.Post("/api/backends/system/delete/:name", func(c *fiber.Ctx) error {
backendName := strings.Clone(c.Params("name"))
app.POST("/api/backends/system/delete/:name", func(c echo.Context) error {
backendName := c.Param("name")
// URL decode the backend name
backendName, err := url.QueryUnescape(backendName)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid backend name",
})
}
@@ -636,24 +721,24 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
// Use the gallery package to delete the backend
if err := gallery.DeleteBackendFromSystem(appConfig.SystemState, backendName); err != nil {
log.Error().Err(err).Msgf("Failed to delete backend: %s", backendName)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"success": true,
"message": "Backend deleted successfully",
})
})
// P2P APIs
app.Get("/api/p2p/workers", func(c *fiber.Ctx) error {
app.GET("/api/p2p/workers", func(c echo.Context) error {
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
nodesJSON := make([]fiber.Map, 0, len(nodes))
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
for _, n := range nodes {
nodesJSON = append(nodesJSON, fiber.Map{
nodesJSON = append(nodesJSON, map[string]interface{}{
"name": n.Name,
"id": n.ID,
"tunnelAddress": n.TunnelAddress,
@@ -663,17 +748,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"nodes": nodesJSON,
})
})
app.Get("/api/p2p/federation", func(c *fiber.Ctx) error {
app.GET("/api/p2p/federation", func(c echo.Context) error {
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
nodesJSON := make([]fiber.Map, 0, len(nodes))
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
for _, n := range nodes {
nodesJSON = append(nodesJSON, fiber.Map{
nodesJSON = append(nodesJSON, map[string]interface{}{
"name": n.Name,
"id": n.ID,
"tunnelAddress": n.TunnelAddress,
@@ -683,12 +768,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"nodes": nodesJSON,
})
})
app.Get("/api/p2p/stats", func(c *fiber.Ctx) error {
app.GET("/api/p2p/stats", func(c echo.Context) error {
workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
@@ -706,12 +791,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
}
return c.JSON(fiber.Map{
"workers": fiber.Map{
return c.JSON(200, map[string]interface{}{
"workers": map[string]interface{}{
"online": workersOnline,
"total": len(workerNodes),
},
"federated": fiber.Map{
"federated": map[string]interface{}{
"online": federatedOnline,
"total": len(federatedNodes),
},

View File

@@ -1,24 +1,24 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
)
func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
// Show the Backends page (all backends are loaded client-side via Alpine.js)
app.Get("/browse/backends", func(c *fiber.Ctx) error {
summary := fiber.Map{
app.GET("/browse/backends", func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI - Backends",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
"Repositories": appConfig.BackendGalleries,
}
// Render index - backends are now loaded via Alpine.js from /api/backends
return c.Render("views/backends", summary)
return c.Render(200, "views/backends", summary)
})
}

View File

@@ -1,24 +1,24 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
)
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
app.Get("/browse", func(c *fiber.Ctx) error {
summary := fiber.Map{
app.GET("/browse/", func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI - Models",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
"Repositories": appConfig.Galleries,
}
// Render index - models are now loaded via Alpine.js from /api/models
return c.Render("views/models", summary)
return c.Render(200, "views/models", summary)
})
}

View File

@@ -0,0 +1 @@
!function(e){"object"==typeof exports&&"object"==typeof module?e(require("../../lib/codemirror")):"function"==typeof define&&define.amd?define(["../../lib/codemirror"],e):e(CodeMirror)}(function(u){"use strict";function f(e,t){clearTimeout(t.timeout),u.off(window,"mouseup",t.hurry),u.off(window,"keyup",t.hurry)}u.defineOption("autoRefresh",!1,function(e,t){function o(){i.display.wrapper.offsetHeight?(f(0,r),i.display.lastWrapHeight!=i.display.wrapper.clientHeight&&i.refresh()):r.timeout=setTimeout(o,r.delay)}var i,r;e.state.autoRefresh&&(f(0,e.state.autoRefresh),e.state.autoRefresh=null),t&&0==e.display.wrapper.offsetHeight&&((r=(i=e).state.autoRefresh={delay:t.delay||250}).timeout=setTimeout(o,r.delay),r.hurry=function(){clearTimeout(r.timeout),r.timeout=setTimeout(o,50)},u.on(window,"mouseup",r.hurry),u.on(window,"keyup",r.hurry))})});

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

22
core/http/static/assets/pdf.min.js vendored Normal file
View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

1
core/http/static/assets/yaml.min.js vendored Normal file
View File

@@ -0,0 +1 @@
!function(e){"object"==typeof exports&&"object"==typeof module?e(require("../../lib/codemirror")):"function"==typeof define&&define.amd?define(["../../lib/codemirror"],e):e(CodeMirror)}(function(e){"use strict";e.defineMode("yaml",function(){var n=new RegExp("\\b(("+["true","false","on","off","yes","no"].join(")|(")+"))$","i");return{token:function(e,i){var t=e.peek(),r=i.escaped;if(i.escaped=!1,"#"==t&&(0==e.pos||/\s/.test(e.string.charAt(e.pos-1))))return e.skipToEnd(),"comment";if(e.match(/^('([^']|\\.)*'?|"([^"]|\\.)*"?)/))return"string";if(i.literal&&e.indentation()>i.keyCol)return e.skipToEnd(),"string";if(i.literal&&(i.literal=!1),e.sol()){if(i.keyCol=0,i.pair=!1,i.pairStart=!1,e.match("---"))return"def";if(e.match("..."))return"def";if(e.match(/\s*-\s+/))return"meta"}if(e.match(/^(\{|\}|\[|\])/))return"{"==t?i.inlinePairs++:"}"==t?i.inlinePairs--:"["==t?i.inlineList++:i.inlineList--,"meta";if(0<i.inlineList&&!r&&","==t)return e.next(),"meta";if(0<i.inlinePairs&&!r&&","==t)return i.keyCol=0,i.pair=!1,i.pairStart=!1,e.next(),"meta";if(i.pairStart){if(e.match(/^\s*(\||\>)\s*/))return i.literal=!0,"meta";if(e.match(/^\s*(\&|\*)[a-z0-9\._-]+\b/i))return"variable-2";if(0==i.inlinePairs&&e.match(/^\s*-?[0-9\.\,]+\s?$/))return"number";if(0<i.inlinePairs&&e.match(/^\s*-?[0-9\.\,]+\s?(?=(,|}))/))return"number";if(e.match(n))return"keyword"}return!i.pair&&e.match(/^\s*(?:[,\[\]{}&*!|>'"%@`][^\s'":]|[^,\[\]{}#&*!|>'"%@`])[^#]*?(?=\s*:($|\s))/)?(i.pair=!0,i.keyCol=e.indentation(),"atom"):i.pair&&e.match(/^:\s*/)?(i.pairStart=!0,"meta"):(i.pairStart=!1,i.escaped="\\"==t,e.next(),null)},startState:function(){return{pair:!1,pairStart:!1,keyCol:0,inlinePairs:0,inlineList:0,literal:!1,escaped:!1}},lineComment:"#",fold:"indent"}}),e.defineMIME("text/x-yaml","yaml"),e.defineMIME("text/yaml","yaml")});

View File

@@ -90,6 +90,23 @@ function updateTokensPerSecond() {
}
}
function scrollThinkingBoxToBottom() {
// Find all thinking/reasoning message containers that are expanded
const thinkingBoxes = document.querySelectorAll('[data-thinking-box]');
thinkingBoxes.forEach(box => {
// Only scroll if the box is visible (expanded) and has overflow
if (box.offsetParent !== null && box.scrollHeight > box.clientHeight) {
box.scrollTo({
top: box.scrollHeight,
behavior: 'smooth'
});
}
});
}
// Make function available globally
window.scrollThinkingBoxToBottom = scrollThinkingBoxToBottom;
function stopRequest() {
if (currentAbortController) {
currentAbortController.abort();
@@ -160,6 +177,9 @@ var images = [];
var audios = [];
var fileContents = [];
var currentFileNames = [];
// Track file names to data URLs for proper removal
var imageFileMap = new Map(); // fileName -> dataURL
var audioFileMap = new Map(); // fileName -> dataURL
async function extractTextFromPDF(pdfData) {
try {
@@ -180,35 +200,119 @@ async function extractTextFromPDF(pdfData) {
}
}
// Global function to handle file selection and update Alpine.js state
window.handleFileSelection = function(event, fileType) {
if (!event.target.files || !event.target.files.length) return;
// Get the Alpine.js component - find the parent div with x-data containing attachedFiles
let inputContainer = event.target.closest('[x-data*="attachedFiles"]');
if (!inputContainer && window.Alpine) {
// Fallback: find any element with attachedFiles in x-data
inputContainer = document.querySelector('[x-data*="attachedFiles"]');
}
if (!inputContainer || !window.Alpine) return;
const alpineData = Alpine.$data(inputContainer);
if (!alpineData || !alpineData.attachedFiles) return;
Array.from(event.target.files).forEach(file => {
// Check if file already exists
const exists = alpineData.attachedFiles.some(f => f.name === file.name && f.type === fileType);
if (!exists) {
alpineData.attachedFiles.push({ name: file.name, type: fileType });
// Process the file based on type
if (fileType === 'image') {
readInputImageFile(file);
} else if (fileType === 'audio') {
readInputAudioFile(file);
} else if (fileType === 'file') {
readInputFileFile(file);
}
}
});
};
// Global function to remove file from input
window.removeFileFromInput = function(fileType, fileName) {
// Remove from arrays
if (fileType === 'image') {
// Remove from images array using the mapping
const dataURL = imageFileMap.get(fileName);
if (dataURL) {
const imageIndex = images.indexOf(dataURL);
if (imageIndex !== -1) {
images.splice(imageIndex, 1);
}
imageFileMap.delete(fileName);
}
} else if (fileType === 'audio') {
// Remove from audios array using the mapping
const dataURL = audioFileMap.get(fileName);
if (dataURL) {
const audioIndex = audios.indexOf(dataURL);
if (audioIndex !== -1) {
audios.splice(audioIndex, 1);
}
audioFileMap.delete(fileName);
}
} else if (fileType === 'file') {
// Remove from fileContents and currentFileNames
const fileIndex = currentFileNames.indexOf(fileName);
if (fileIndex !== -1) {
currentFileNames.splice(fileIndex, 1);
fileContents.splice(fileIndex, 1);
}
}
// Also remove from the actual input element
const inputId = fileType === 'image' ? 'input_image' :
fileType === 'audio' ? 'input_audio' : 'input_file';
const input = document.getElementById(inputId);
if (input && input.files) {
const dt = new DataTransfer();
Array.from(input.files).forEach(file => {
if (file.name !== fileName) {
dt.items.add(file);
}
});
input.files = dt.files;
}
};
function readInputFile() {
if (!this.files || !this.files.length) return;
Array.from(this.files).forEach(file => {
const FR = new FileReader();
currentFileNames.push(file.name);
const fileExtension = file.name.split('.').pop().toLowerCase();
FR.addEventListener("load", async function(evt) {
if (fileExtension === 'pdf') {
try {
const content = await extractTextFromPDF(evt.target.result);
fileContents.push({ name: file.name, content: content });
} catch (error) {
console.error('Error processing PDF:', error);
fileContents.push({ name: file.name, content: "Error processing PDF file" });
}
} else {
// For text and markdown files
fileContents.push({ name: file.name, content: evt.target.result });
}
});
readInputFileFile(file);
});
}
function readInputFileFile(file) {
const FR = new FileReader();
currentFileNames.push(file.name);
const fileExtension = file.name.split('.').pop().toLowerCase();
FR.addEventListener("load", async function(evt) {
if (fileExtension === 'pdf') {
FR.readAsArrayBuffer(file);
try {
const content = await extractTextFromPDF(evt.target.result);
fileContents.push({ name: file.name, content: content });
} catch (error) {
console.error('Error processing PDF:', error);
fileContents.push({ name: file.name, content: "Error processing PDF file" });
}
} else {
FR.readAsText(file);
// For text and markdown files
fileContents.push({ name: file.name, content: evt.target.result });
}
});
if (fileExtension === 'pdf') {
FR.readAsArrayBuffer(file);
} else {
FR.readAsText(file);
}
}
function submitPrompt(event) {
@@ -267,7 +371,15 @@ function processAndSendMessage(inputValue) {
const input = document.getElementById("input");
if (input) input.value = "";
const systemPrompt = localStorage.getItem("system_prompt");
Alpine.nextTick(() => { document.getElementById('messages').scrollIntoView(false); });
Alpine.nextTick(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
});
// Reset token tracking before starting new request
requestStartTime = Date.now();
@@ -278,36 +390,66 @@ function processAndSendMessage(inputValue) {
// Reset file contents and names after sending
fileContents = [];
currentFileNames = [];
images = [];
audios = [];
imageFileMap.clear();
audioFileMap.clear();
// Clear Alpine.js attachedFiles array
const inputContainer = document.querySelector('[x-data*="attachedFiles"]');
if (inputContainer && window.Alpine) {
const alpineData = Alpine.$data(inputContainer);
if (alpineData && alpineData.attachedFiles) {
alpineData.attachedFiles = [];
}
}
// Clear file inputs
document.getElementById("input_image").value = null;
document.getElementById("input_audio").value = null;
document.getElementById("input_file").value = null;
}
function readInputImage() {
if (!this.files || !this.files.length) return;
Array.from(this.files).forEach(file => {
const FR = new FileReader();
FR.addEventListener("load", function(evt) {
images.push(evt.target.result);
});
FR.readAsDataURL(file);
readInputImageFile(file);
});
}
function readInputImageFile(file) {
const FR = new FileReader();
FR.addEventListener("load", function(evt) {
const dataURL = evt.target.result;
images.push(dataURL);
imageFileMap.set(file.name, dataURL);
});
FR.readAsDataURL(file);
}
function readInputAudio() {
if (!this.files || !this.files.length) return;
Array.from(this.files).forEach(file => {
const FR = new FileReader();
FR.addEventListener("load", function(evt) {
audios.push(evt.target.result);
});
FR.readAsDataURL(file);
readInputAudioFile(file);
});
}
function readInputAudioFile(file) {
const FR = new FileReader();
FR.addEventListener("load", function(evt) {
const dataURL = evt.target.result;
audios.push(dataURL);
audioFileMap.set(file.name, dataURL);
});
FR.readAsDataURL(file);
}
async function promptGPT(systemPrompt, input) {
const model = document.getElementById("chat-model").value;
const mcpMode = Alpine.store("chat").mcpMode;
@@ -370,25 +512,18 @@ async function promptGPT(systemPrompt, input) {
}
});
// reset the form and the files
images = [];
audios = [];
document.getElementById("input_image").value = null;
document.getElementById("input_audio").value = null;
document.getElementById("input_file").value = null;
document.getElementById("fileName").innerHTML = "";
// reset the form and the files (already done in processAndSendMessage)
// images, audios, and file inputs are cleared after sending
// Choose endpoint based on MCP mode
const endpoint = mcpMode ? "mcp/v1/chat/completions" : "v1/chat/completions";
const endpoint = mcpMode ? "v1/mcp/chat/completions" : "v1/chat/completions";
const requestBody = {
model: model,
messages: messages,
};
// Only add stream parameter for regular chat (MCP doesn't support streaming)
if (!mcpMode) {
requestBody.stream = true;
}
// Add stream parameter for both regular chat and MCP (MCP now supports SSE streaming)
requestBody.stream = true;
let response;
try {
@@ -444,64 +579,441 @@ async function promptGPT(systemPrompt, input) {
return;
}
// Handle streaming response (both regular and MCP mode now use SSE)
if (mcpMode) {
// Handle MCP non-streaming response
// Handle MCP SSE streaming with new event types
const reader = response.body
?.pipeThrough(new TextDecoderStream())
.getReader();
if (!reader) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to decode MCP API response</span>`,
);
toggleLoader(false);
return;
}
// Store reader globally so stop button can cancel it
currentReader = reader;
let buffer = "";
let assistantContent = "";
let assistantContentBuffer = [];
let thinkingContent = "";
let isThinking = false;
let lastAssistantMessageIndex = -1;
let lastThinkingMessageIndex = -1;
let lastThinkingScrollTime = 0;
const THINKING_SCROLL_THROTTLE = 200; // Throttle scrolling to every 200ms
try {
const data = await response.json();
// Update token usage if present
if (data.usage) {
Alpine.store("chat").updateTokenUsage(data.usage);
}
// MCP endpoint returns content in choices[0].message.content (chat completion format)
// Fallback to choices[0].text for backward compatibility (completion format)
const content = data.choices[0]?.message?.content || data.choices[0]?.text || "";
if (!content && (!data.choices || data.choices.length === 0)) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Empty response from MCP endpoint</span>`,
);
toggleLoader(false);
return;
}
if (content) {
// Count tokens for rate calculation (MCP mode - full content at once)
// Prefer actual token count from API if available
if (data.usage && data.usage.completion_tokens) {
tokensReceived = data.usage.completion_tokens;
} else {
tokensReceived += Math.ceil(content.length / 4);
while (true) {
const { value, done } = await reader.read();
if (done) break;
buffer += value;
let lines = buffer.split("\n");
buffer = lines.pop(); // Retain any incomplete line in the buffer
lines.forEach((line) => {
if (line.length === 0 || line.startsWith(":")) return;
if (line === "data: [DONE]") {
return;
}
if (line.startsWith("data: ")) {
try {
const eventData = JSON.parse(line.substring(6));
// Handle different event types
switch (eventData.type) {
case "reasoning":
if (eventData.content) {
const chatStore = Alpine.store("chat");
// Insert reasoning before assistant message if it exists
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
chatStore.history.splice(lastAssistantMessageIndex, 0, {
role: "reasoning",
content: eventData.content,
html: DOMPurify.sanitize(marked.parse(eventData.content)),
image: [],
audio: [],
expanded: false // Reasoning is always collapsed
});
lastAssistantMessageIndex++; // Adjust index since we inserted
// Scroll smoothly after adding reasoning
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 100);
} else {
// No assistant message yet, just add normally
chatStore.add("reasoning", eventData.content);
}
}
break;
case "tool_call":
if (eventData.name) {
// Store as JSON for better formatting
const toolCallData = {
name: eventData.name,
arguments: eventData.arguments || {},
reasoning: eventData.reasoning || ""
};
Alpine.store("chat").add("tool_call", JSON.stringify(toolCallData, null, 2));
// Scroll smoothly after adding tool call
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 100);
}
break;
case "tool_result":
if (eventData.name) {
// Store as JSON for better formatting
const toolResultData = {
name: eventData.name,
result: eventData.result || ""
};
Alpine.store("chat").add("tool_result", JSON.stringify(toolResultData, null, 2));
// Scroll smoothly after adding tool result
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 100);
}
break;
case "status":
// Status messages can be logged but not necessarily displayed
console.log("[MCP Status]", eventData.message);
break;
case "assistant":
if (eventData.content) {
assistantContent += eventData.content;
const contentChunk = eventData.content;
// Count tokens for rate calculation
tokensReceived += Math.ceil(contentChunk.length / 4);
updateTokensPerSecond();
// Check for thinking tags in the chunk (incremental detection)
if (contentChunk.includes("<thinking>") || contentChunk.includes("<think>")) {
isThinking = true;
thinkingContent = "";
lastThinkingMessageIndex = -1;
}
if (contentChunk.includes("</thinking>") || contentChunk.includes("</think>")) {
isThinking = false;
// When closing tag is detected, process the accumulated thinking content
if (thinkingContent.trim()) {
// Extract just the thinking part from the accumulated content
const thinkingMatch = thinkingContent.match(/<(?:thinking|redacted_reasoning)>(.*?)<\/(?:thinking|redacted_reasoning)>/s);
if (thinkingMatch && thinkingMatch[1]) {
const extractedThinking = thinkingMatch[1];
const chatStore = Alpine.store("chat");
const isMCPMode = chatStore.mcpMode || false;
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
if (lastThinkingMessageIndex === -1) {
// Insert thinking before the last assistant message if it exists
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
// Insert before assistant message
chatStore.history.splice(lastAssistantMessageIndex, 0, {
role: "thinking",
content: extractedThinking,
html: DOMPurify.sanitize(marked.parse(extractedThinking)),
image: [],
audio: [],
expanded: shouldExpand
});
lastThinkingMessageIndex = lastAssistantMessageIndex;
lastAssistantMessageIndex++; // Adjust index since we inserted
} else {
// No assistant message yet, just add normally
chatStore.add("thinking", extractedThinking);
lastThinkingMessageIndex = chatStore.history.length - 1;
}
} else {
// Update existing thinking message
const lastMessage = chatStore.history[lastThinkingMessageIndex];
if (lastMessage && lastMessage.role === "thinking") {
lastMessage.content = extractedThinking;
lastMessage.html = DOMPurify.sanitize(marked.parse(extractedThinking));
}
}
// Scroll when thinking is finalized in non-MCP mode
if (!isMCPMode) {
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 50);
}
}
thinkingContent = "";
}
}
// Handle content based on thinking state
if (isThinking) {
thinkingContent += contentChunk;
const chatStore = Alpine.store("chat");
const isMCPMode = chatStore.mcpMode || false;
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
// Update the last thinking message or create a new one (incremental)
if (lastThinkingMessageIndex === -1) {
// Insert thinking before the last assistant message if it exists
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
// Insert before assistant message
chatStore.history.splice(lastAssistantMessageIndex, 0, {
role: "thinking",
content: thinkingContent,
html: DOMPurify.sanitize(marked.parse(thinkingContent)),
image: [],
audio: [],
expanded: shouldExpand
});
lastThinkingMessageIndex = lastAssistantMessageIndex;
lastAssistantMessageIndex++; // Adjust index since we inserted
} else {
// No assistant message yet, just add normally
chatStore.add("thinking", thinkingContent);
lastThinkingMessageIndex = chatStore.history.length - 1;
}
} else {
// Update existing thinking message
const lastMessage = chatStore.history[lastThinkingMessageIndex];
if (lastMessage && lastMessage.role === "thinking") {
lastMessage.content = thinkingContent;
lastMessage.html = DOMPurify.sanitize(marked.parse(thinkingContent));
}
}
// Scroll when thinking is updated in non-MCP mode (throttled)
if (!isMCPMode) {
const now = Date.now();
if (now - lastThinkingScrollTime > THINKING_SCROLL_THROTTLE) {
lastThinkingScrollTime = now;
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 100);
}
}
} else {
// Regular assistant content - buffer it for batch processing
assistantContentBuffer.push(contentChunk);
}
}
break;
case "error":
Alpine.store("chat").add(
"assistant",
`<span class='error'>MCP Error: ${eventData.message}</span>`,
);
break;
}
} catch (error) {
console.error("Failed to parse MCP event:", line, error);
}
}
});
// Efficiently update assistant message in batch
if (assistantContentBuffer.length > 0) {
const regularContent = assistantContentBuffer.join("");
// Process any thinking tags that might be in the accumulated content
// This handles cases where tags are split across chunks
const { regularContent: processedRegular, thinkingContent: processedThinking } = processThinkingTags(regularContent);
// Update or create assistant message with processed regular content
if (lastAssistantMessageIndex === -1) {
if (processedRegular && processedRegular.trim()) {
Alpine.store("chat").add("assistant", processedRegular);
lastAssistantMessageIndex = Alpine.store("chat").history.length - 1;
}
} else {
const chatStore = Alpine.store("chat");
const lastMessage = chatStore.history[lastAssistantMessageIndex];
if (lastMessage && lastMessage.role === "assistant") {
lastMessage.content = (lastMessage.content || "") + (processedRegular || "");
lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content));
}
}
// Add any extracted thinking content from the processed buffer BEFORE assistant message
if (processedThinking && processedThinking.trim()) {
const chatStore = Alpine.store("chat");
const isMCPMode = chatStore.mcpMode || false;
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
// Insert thinking before assistant message if it exists
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
chatStore.history.splice(lastAssistantMessageIndex, 0, {
role: "thinking",
content: processedThinking,
html: DOMPurify.sanitize(marked.parse(processedThinking)),
image: [],
audio: [],
expanded: shouldExpand
});
lastAssistantMessageIndex++; // Adjust index since we inserted
} else {
// No assistant message yet, just add normally
chatStore.add("thinking", processedThinking);
}
}
assistantContentBuffer = [];
}
updateTokensPerSecond();
}
// Final assistant content flush if any data remains
if (assistantContentBuffer.length > 0) {
const regularContent = assistantContentBuffer.join("");
// Process any remaining thinking tags that might be in the buffer
const { regularContent: processedRegular, thinkingContent: processedThinking } = processThinkingTags(regularContent);
// Process thinking tags using shared function
const { regularContent, thinkingContent } = processThinkingTags(content);
const chatStore = Alpine.store("chat");
// Add thinking content if present
if (thinkingContent) {
// First, add any extracted thinking content BEFORE assistant message
if (processedThinking && processedThinking.trim()) {
const isMCPMode = chatStore.mcpMode || false;
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
// Insert thinking before assistant message if it exists
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
chatStore.history.splice(lastAssistantMessageIndex, 0, {
role: "thinking",
content: processedThinking,
html: DOMPurify.sanitize(marked.parse(processedThinking)),
image: [],
audio: [],
expanded: shouldExpand
});
lastAssistantMessageIndex++; // Adjust index since we inserted
} else {
// No assistant message yet, just add normally
chatStore.add("thinking", processedThinking);
}
}
// Then update or create assistant message
if (lastAssistantMessageIndex !== -1) {
const lastMessage = chatStore.history[lastAssistantMessageIndex];
if (lastMessage && lastMessage.role === "assistant") {
lastMessage.content = (lastMessage.content || "") + (processedRegular || "");
lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content));
}
} else if (processedRegular && processedRegular.trim()) {
chatStore.add("assistant", processedRegular);
lastAssistantMessageIndex = chatStore.history.length - 1;
}
}
// Final thinking content flush if any data remains (from incremental detection)
if (thinkingContent.trim() && lastThinkingMessageIndex === -1) {
// Extract thinking content if tags are present
const thinkingMatch = thinkingContent.match(/<(?:thinking|redacted_reasoning)>(.*?)<\/(?:thinking|redacted_reasoning)>/s);
if (thinkingMatch && thinkingMatch[1]) {
const chatStore = Alpine.store("chat");
const isMCPMode = chatStore.mcpMode || false;
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
// Insert thinking before assistant message if it exists
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
chatStore.history.splice(lastAssistantMessageIndex, 0, {
role: "thinking",
content: thinkingMatch[1],
html: DOMPurify.sanitize(marked.parse(thinkingMatch[1])),
image: [],
audio: [],
expanded: shouldExpand
});
} else {
// No assistant message yet, just add normally
chatStore.add("thinking", thinkingMatch[1]);
}
} else {
Alpine.store("chat").add("thinking", thinkingContent);
}
// Add regular content if present
if (regularContent) {
Alpine.store("chat").add("assistant", regularContent);
}
}
// Highlight all code blocks
// Final pass: process the entire assistantContent to catch any missed thinking tags
// This ensures we don't miss tags that were split across chunks
if (assistantContent.trim()) {
const { regularContent: finalRegular, thinkingContent: finalThinking } = processThinkingTags(assistantContent);
// Update assistant message with final processed content (without thinking tags)
if (finalRegular && finalRegular.trim()) {
if (lastAssistantMessageIndex !== -1) {
const chatStore = Alpine.store("chat");
const lastMessage = chatStore.history[lastAssistantMessageIndex];
if (lastMessage && lastMessage.role === "assistant") {
lastMessage.content = finalRegular;
lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content));
}
} else {
Alpine.store("chat").add("assistant", finalRegular);
}
}
// Add any extracted thinking content (only if not already added)
if (finalThinking && finalThinking.trim()) {
const hasThinking = Alpine.store("chat").history.some(msg =>
msg.role === "thinking" && msg.content.trim() === finalThinking.trim()
);
if (!hasThinking) {
Alpine.store("chat").add("thinking", finalThinking);
}
}
}
// Highlight all code blocks once at the end
hljs.highlightAll();
} catch (error) {
// Don't show error if request was aborted by user
if (error.name !== 'AbortError' || currentAbortController) {
if (error.name !== 'AbortError' || !currentAbortController) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to parse MCP response</span>`,
`<span class='error'>Error: Failed to process MCP stream</span>`,
);
}
} finally {
// Perform any cleanup if necessary
if (reader) {
reader.releaseLock();
}
currentReader = null;
currentAbortController = null;
}
} else {
@@ -539,6 +1051,8 @@ async function promptGPT(systemPrompt, input) {
let thinkingContent = "";
let isThinking = false;
let lastThinkingMessageIndex = -1;
let lastThinkingScrollTime = 0;
const THINKING_SCROLL_THROTTLE = 200; // Throttle scrolling to every 200ms
try {
while (true) {
@@ -606,6 +1120,23 @@ async function promptGPT(systemPrompt, input) {
lastMessage.html = DOMPurify.sanitize(marked.parse(thinkingContent));
}
}
// Scroll when thinking is updated (throttled)
const now = Date.now();
if (now - lastThinkingScrollTime > THINKING_SCROLL_THROTTLE) {
lastThinkingScrollTime = now;
setTimeout(() => {
// Scroll main chat container
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
// Scroll thinking box to bottom if it's expanded and scrollable
scrollThinkingBoxToBottom();
}, 100);
}
} else {
contentBuffer.push(token);
}
@@ -620,6 +1151,16 @@ async function promptGPT(systemPrompt, input) {
if (contentBuffer.length > 0) {
addToChat(contentBuffer.join(""));
contentBuffer = [];
// Scroll when assistant content is updated (this will also show thinking messages above)
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 50);
}
}
@@ -654,8 +1195,17 @@ async function promptGPT(systemPrompt, input) {
// Remove class "loader" from the element with "loader" id
toggleLoader(false);
// scroll to the bottom of the chat
document.getElementById('messages').scrollIntoView(false)
// scroll to the bottom of the chat consistently
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
}, 100);
// set focus to the input
document.getElementById("input").focus();
}
@@ -748,8 +1298,8 @@ document.addEventListener("alpine:init", () => {
},
add(role, content, image, audio) {
const N = this.history.length - 1;
// For thinking messages, always create a new message
if (role === "thinking") {
// For thinking and reasoning messages, always create a new message
if (role === "thinking" || role === "reasoning") {
let c = "";
const lines = content.split("\n");
lines.forEach((line) => {
@@ -784,7 +1334,21 @@ document.addEventListener("alpine:init", () => {
audio: audio || []
});
}
document.getElementById('messages').scrollIntoView(false);
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
// Also scroll thinking box if it's a thinking/reasoning message
if (role === "thinking" || role === "reasoning") {
setTimeout(() => {
if (typeof window.scrollThinkingBoxToBottom === 'function') {
window.scrollThinkingBoxToBottom();
}
}, 100);
}
const parser = new DOMParser();
const html = parser.parseFromString(
this.history[this.history.length - 1].html,
@@ -812,3 +1376,56 @@ document.addEventListener("alpine:init", () => {
});
}
});
// Check for message from index page on load
document.addEventListener('DOMContentLoaded', function() {
// Wait for Alpine to be ready
setTimeout(() => {
const chatData = localStorage.getItem('localai_index_chat_data');
if (chatData) {
try {
const data = JSON.parse(chatData);
const input = document.getElementById('input');
if (input && data.message) {
// Set the message in the input
input.value = data.message;
// Process files if any
if (data.imageFiles && data.imageFiles.length > 0) {
data.imageFiles.forEach(file => {
images.push(file.data);
});
}
if (data.audioFiles && data.audioFiles.length > 0) {
data.audioFiles.forEach(file => {
audios.push(file.data);
});
}
if (data.textFiles && data.textFiles.length > 0) {
data.textFiles.forEach(file => {
fileContents.push({ name: file.name, content: file.data });
currentFileNames.push(file.name);
});
}
// Clear localStorage
localStorage.removeItem('localai_index_chat_data');
// Auto-submit after a short delay to ensure everything is ready
setTimeout(() => {
if (input.value.trim()) {
processAndSendMessage(input.value);
}
}, 500);
}
} catch (error) {
console.error('Error processing chat data from index:', error);
localStorage.removeItem('localai_index_chat_data');
}
}
}, 300);
});

View File

@@ -1,24 +0,0 @@
package utils
import (
"strings"
"github.com/gofiber/fiber/v2"
)
// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
func BaseURL(c *fiber.Ctx) string {
path := c.Path()
origPath := c.OriginalURL()
if path != origPath && strings.HasSuffix(origPath, path) {
pathPrefix := origPath[:len(origPath)-len(path)+1]
return c.BaseURL() + pathPrefix
}
return c.BaseURL() + "/"
}

View File

@@ -1,48 +0,0 @@
package utils
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestBaseURL(t *testing.T) {
for _, tc := range []struct {
name string
prefix string
expectURL string
}{
{
name: "without prefix",
prefix: "/",
expectURL: "http://example.com/",
},
{
name: "with prefix",
prefix: "/myprefix/",
expectURL: "http://example.com/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
app := fiber.New()
actualURL := ""
app.Get(tc.prefix+"hello/world", func(c *fiber.Ctx) error {
if tc.prefix != "/" {
c.Path("/hello/world")
}
actualURL = BaseURL(c)
return nil
})
req := httptest.NewRequest("GET", tc.prefix+"hello/world", nil)
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode, "response status code")
require.Equal(t, tc.expectURL, actualURL, "base URL")
})
}
}

View File

@@ -9,9 +9,9 @@
<div class="container mx-auto px-4 py-8 flex-grow">
<!-- Error Section -->
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-2xl shadow-2xl shadow-[#38BDF8]/10 p-8 mb-10">
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl p-8 mb-10">
<div class="max-w-4xl mx-auto text-center">
<div class="mb-6 text-6xl text-[#38BDF8] animate-pulse">
<div class="mb-6 text-6xl text-[#38BDF8]">
<i class="fas fa-exclamation-circle"></i>
</div>
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
@@ -22,23 +22,21 @@
<p class="text-xl text-[#94A3B8] mb-6">The page you're looking for doesn't exist or has been moved</p>
<div class="flex flex-wrap justify-center gap-4">
<a href="./"
class="group flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(56,189,248,0.4)]">
class="inline-flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition-colors">
<i class="fas fa-home mr-2"></i>
<span>Return Home</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</a>
<a href="browse"
class="group flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(139,92,246,0.4)]">
class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition-colors">
<i class="fas fa-images mr-2"></i>
<span>Browse Gallery</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</a>
</div>
</div>
</div>
<!-- Additional Information -->
<div class="bg-[#1E293B]/80 border border-[#1E293B] rounded-xl p-8 shadow-lg backdrop-blur-sm">
<div class="bg-[#1E293B] border border-[#1E293B] rounded-xl p-8">
<div class="text-center max-w-3xl mx-auto">
<div class="inline-flex items-center justify-center w-16 h-16 rounded-full bg-yellow-500/10 border border-yellow-500/20 mb-4">
<i class="text-yellow-400 text-2xl fa-solid fa-triangle-exclamation"></i>

View File

@@ -11,21 +11,21 @@
<div class="fixed top-20 right-4 z-50 space-y-2" style="max-width: 400px;">
<template x-for="notification in notifications" :key="notification.id">
<div x-show="true"
x-transition:enter="transform ease-out duration-300 transition"
x-transition:enter-start="translate-x-full opacity-0"
x-transition:enter-end="translate-x-0 opacity-100"
x-transition:leave="transform ease-in duration-200 transition"
x-transition:leave-start="translate-x-0 opacity-100"
x-transition:leave-end="translate-x-full opacity-0"
x-transition:enter="transition ease-out duration-200"
x-transition:enter-start="opacity-0"
x-transition:enter-end="opacity-100"
x-transition:leave="transition ease-in duration-150"
x-transition:leave-start="opacity-100"
x-transition:leave-end="opacity-0"
:class="notification.type === 'error' ? 'bg-red-500' : 'bg-green-500'"
class="rounded-lg shadow-xl p-4 text-white flex items-start space-x-3">
class="rounded-lg p-4 text-white flex items-start space-x-3">
<div class="flex-shrink-0">
<i :class="notification.type === 'error' ? 'fas fa-exclamation-circle' : 'fas fa-check-circle'" class="text-xl"></i>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium break-words" x-text="notification.message"></p>
</div>
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:text-gray-200">
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:opacity-80 transition-opacity">
<i class="fas fa-times"></i>
</button>
</div>
@@ -35,14 +35,8 @@
<div class="container mx-auto px-4 py-8 flex-grow">
<!-- Hero Header -->
<div class="relative bg-[#1E293B] border border-[#8B5CF6]/20 rounded-3xl shadow-2xl shadow-[#8B5CF6]/10 p-8 mb-12 overflow-hidden">
<!-- Background Pattern -->
<div class="absolute inset-0 opacity-10">
<div class="absolute inset-0 bg-gradient-to-r from-[#8B5CF6]/20 to-[#38BDF8]/20"></div>
<div class="absolute top-0 left-0 w-full h-full" style="background-image: radial-gradient(circle at 1px 1px, rgba(139,92,246,0.15) 1px, transparent 0); background-size: 20px 20px;"></div>
</div>
<div class="relative max-w-5xl mx-auto text-center">
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8 mb-12">
<div class="max-w-5xl mx-auto text-center">
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
<span class="bg-clip-text text-transparent bg-gradient-to-r from-[#8B5CF6] via-[#38BDF8] to-[#8B5CF6]">
Backend Management
@@ -52,13 +46,18 @@
Discover and install AI backends to power your models
</p>
<div class="flex flex-wrap justify-center items-center gap-6 text-sm md:text-base">
<div class="flex items-center bg-white/10 rounded-full px-4 py-2">
<div class="w-2 h-2 bg-emerald-400 rounded-full mr-2 animate-pulse"></div>
<div class="flex items-center bg-[#101827] rounded-lg px-4 py-2">
<div class="w-2 h-2 bg-emerald-400 rounded-full mr-2"></div>
<span class="font-semibold text-emerald-300" x-text="availableBackends"></span>
<span class="text-gray-300 ml-1">backends available</span>
<span class="text-[#94A3B8] ml-1">backends available</span>
</div>
<a href="/manage" class="flex items-center bg-[#101827] hover:bg-[#1E293B] rounded-lg px-4 py-2 transition-colors border border-[#8B5CF6]/30 hover:border-[#8B5CF6]/50">
<div class="w-2 h-2 bg-cyan-400 rounded-full mr-2"></div>
<span class="font-semibold text-cyan-300" x-text="installedBackends"></span>
<span class="text-[#94A3B8] ml-1">installed</span>
</a>
<a href="https://localai.io/backends/" target="_blank"
class="flex items-center bg-cyan-600/80 hover:bg-cyan-600 text-white px-4 py-2 rounded-full transition-all duration-300 hover:scale-105">
class="inline-flex items-center bg-cyan-600 hover:bg-cyan-700 text-white px-4 py-2 rounded-lg transition-colors">
<i class="fas fa-info-circle mr-2"></i>
<span>Documentation</span>
<i class="fas fa-external-link-alt ml-2 text-xs"></i>
@@ -70,28 +69,26 @@
{{template "views/partials/inprogress" .}}
<!-- Search and Filter Section -->
<div class="relative bg-gradient-to-br from-gray-800/80 to-gray-900/80 rounded-2xl p-8 mb-8 shadow-xl border border-gray-700/50 backdrop-blur-sm">
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-emerald-500/5 to-cyan-500/5"></div>
<div class="relative">
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8 mb-8">
<div>
<!-- Search Input -->
<div class="mb-8">
<h3 class="text-xl font-semibold text-white mb-4 flex items-center">
<i class="fas fa-search mr-3 text-emerald-400"></i>
<h3 class="text-xl font-semibold text-[#E5E7EB] mb-4 flex items-center">
<i class="fas fa-search mr-3 text-[#8B5CF6]"></i>
Find Backend Components
</h3>
<div class="relative">
<div class="absolute inset-y-0 start-0 flex items-center ps-4 pointer-events-none">
<i class="fas fa-search text-gray-400"></i>
<i class="fas fa-search text-[#94A3B8]"></i>
</div>
<input
x-model="searchTerm"
@input.debounce.500ms="fetchBackends()"
class="w-full pl-12 pr-16 py-4 text-base font-normal text-gray-300 bg-gray-900/90 border border-gray-700/70 rounded-xl transition-all duration-300 focus:text-gray-200 focus:bg-gray-900 focus:border-emerald-500 focus:ring-2 focus:ring-emerald-500/50 focus:outline-none"
class="w-full pl-12 pr-16 py-4 text-base font-normal text-[#E5E7EB] bg-[#101827] border border-[#1E293B] rounded-lg transition-colors focus:text-[#E5E7EB] focus:bg-[#101827] focus:border-[#8B5CF6] focus:ring-2 focus:ring-[#8B5CF6]/50 focus:outline-none"
type="search"
placeholder="Search backends by name, description or type...">
<span class="absolute right-4 top-4" x-show="loading">
<svg class="animate-spin h-6 w-6 text-emerald-500" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<svg class="animate-spin h-6 w-6 text-[#8B5CF6]" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
@@ -107,28 +104,28 @@
</h3>
<div class="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-5 gap-3">
<button @click="filterByTerm('llm')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-indigo-600/80 to-indigo-700/80 hover:from-indigo-600 hover:to-indigo-700 text-indigo-100 border border-indigo-500/30 hover:border-indigo-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-indigo-500/25">
<i class="fas fa-brain mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-indigo-600/20 hover:bg-indigo-600/30 text-indigo-300 border border-indigo-500/30 transition-colors">
<i class="fas fa-brain mr-2"></i>
<span>LLM</span>
</button>
<button @click="filterByTerm('diffusion')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-purple-600/80 to-purple-700/80 hover:from-purple-600 hover:to-purple-700 text-purple-100 border border-purple-500/30 hover:border-purple-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-purple-500/25">
<i class="fas fa-image mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-purple-600/20 hover:bg-purple-600/30 text-purple-300 border border-purple-500/30 transition-colors">
<i class="fas fa-image mr-2"></i>
<span>Diffusion</span>
</button>
<button @click="filterByTerm('tts')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-blue-600/80 to-blue-700/80 hover:from-blue-600 hover:to-blue-700 text-blue-100 border border-blue-500/30 hover:border-blue-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-blue-500/25">
<i class="fas fa-microphone mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-blue-600/20 hover:bg-blue-600/30 text-blue-300 border border-blue-500/30 transition-colors">
<i class="fas fa-microphone mr-2"></i>
<span>TTS</span>
</button>
<button @click="filterByTerm('whisper')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-green-600/80 to-green-700/80 hover:from-green-600 hover:to-green-700 text-green-100 border border-green-500/30 hover:border-green-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-green-500/25">
<i class="fas fa-headphones mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-green-600/20 hover:bg-green-600/30 text-green-300 border border-green-500/30 transition-colors">
<i class="fas fa-headphones mr-2"></i>
<span>Whisper</span>
</button>
<button @click="filterByTerm('object-detection')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-red-600/80 to-red-700/80 hover:from-red-600 hover:to-red-700 text-red-100 border border-red-500/30 hover:border-red-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-red-500/25">
<i class="fas fa-eye mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-red-600/20 hover:bg-red-600/30 text-red-300 border border-red-500/30 transition-colors">
<i class="fas fa-eye mr-2"></i>
<span>Vision</span>
</button>
</div>
@@ -171,10 +168,14 @@
<tr class="hover:bg-[#38BDF8]/10 transition-colors duration-200">
<!-- Icon -->
<td class="px-6 py-4">
<img :src="backend.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
class="w-12 h-12 object-cover rounded-lg border border-[#38BDF8]/30"
loading="lazy"
:alt="backend.name">
<div class="w-12 h-12 rounded-lg border border-[#38BDF8]/30 flex items-center justify-center bg-[#101827]">
<img x-show="backend.icon"
:src="backend.icon"
class="w-full h-full object-cover rounded-lg"
loading="lazy"
:alt="backend.name">
<i x-show="!backend.icon" class="fas fa-cog text-xl text-[#8B5CF6]"></i>
</div>
</td>
<!-- Backend Name -->
@@ -301,9 +302,13 @@
<!-- Modal Body -->
<div class="p-4 md:p-5 space-y-4 overflow-y-auto flex-1 min-h-0">
<div class="flex justify-center items-center">
<img :src="selectedBackend?.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
class="rounded-t-lg max-h-48 max-w-96 object-cover mt-3"
loading="lazy">
<div class="w-48 h-48 rounded-lg border border-gray-300 dark:border-gray-600 flex items-center justify-center bg-gray-100 dark:bg-gray-800 mt-3">
<img x-show="selectedBackend?.icon"
:src="selectedBackend?.icon"
class="rounded-lg max-h-48 max-w-96 object-cover"
loading="lazy">
<i x-show="!selectedBackend?.icon" class="fas fa-cog text-6xl text-gray-400 dark:text-gray-500"></i>
</div>
</div>
<div class="text-base leading-relaxed text-gray-500 dark:text-gray-400 break-words max-w-full markdown-content" x-html="renderMarkdown(selectedBackend?.description)"></div>
<template x-if="selectedBackend?.tags && selectedBackend.tags.length > 0">
@@ -353,8 +358,8 @@
<button @click="goToPage(currentPage - 1)"
:disabled="currentPage <= 1"
:class="currentPage <= 1 ? 'opacity-50 cursor-not-allowed' : ''"
class="group flex items-center justify-center h-12 w-12 bg-gray-700/80 hover:bg-emerald-600 text-gray-300 hover:text-white rounded-xl shadow-lg transition-all duration-300 ease-in-out transform hover:scale-110">
<i class="fas fa-chevron-left group-hover:animate-pulse"></i>
class="flex items-center justify-center h-12 w-12 bg-[#1E293B] hover:bg-emerald-600 text-[#94A3B8] hover:text-white rounded-lg transition-colors">
<i class="fas fa-chevron-left"></i>
</button>
<div class="text-gray-300 text-sm font-medium px-4">
<span class="text-gray-400">Page</span>
@@ -488,6 +493,7 @@ function backendsGallery() {
currentPage: 1,
totalPages: 1,
availableBackends: 0,
installedBackends: 0,
selectedBackend: null,
jobProgress: {},
notifications: [],
@@ -526,6 +532,7 @@ function backendsGallery() {
this.currentPage = data.currentPage || 1;
this.totalPages = data.totalPages || 1;
this.availableBackends = data.availableBackends || 0;
this.installedBackends = data.installedBackends || 0;
} catch (error) {
console.error('Error fetching backends:', error);
} finally {

View File

@@ -28,10 +28,10 @@ SOFTWARE.
<!doctype html>
<html lang="en">
{{template "views/partials/head" .}}
<script src="https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.11.174/pdf.min.js"></script>
<script src="static/assets/pdf.min.js"></script>
<script>
// Initialize PDF.js worker
pdfjsLib.GlobalWorkerOptions.workerSrc = 'https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.11.174/pdf.worker.min.js';
pdfjsLib.GlobalWorkerOptions.workerSrc = 'static/assets/pdf.worker.min.js';
</script>
<script>
// Initialize Alpine store - must run before Alpine processes DOM
@@ -111,14 +111,36 @@ SOFTWARE.
},
add(role, content, image, audio) {
const N = this.history.length - 1;
// For thinking messages, always create a new message
if (role === "thinking") {
// For thinking, reasoning, tool_call, and tool_result messages, always create a new message
if (role === "thinking" || role === "reasoning" || role === "tool_call" || role === "tool_result") {
let c = "";
const lines = content.split("\n");
lines.forEach((line) => {
c += DOMPurify.sanitize(marked.parse(line));
});
this.history.push({ role, content, html: c, image, audio });
if (role === "tool_call" || role === "tool_result") {
// For tool calls and results, try to parse as JSON and format nicely
try {
const parsed = typeof content === 'string' ? JSON.parse(content) : content;
// Format JSON with proper indentation
const formatted = JSON.stringify(parsed, null, 2);
c = DOMPurify.sanitize('<pre><code class="language-json">' + formatted + '</code></pre>');
} catch (e) {
// If not JSON, treat as markdown
const lines = content.split("\n");
lines.forEach((line) => {
c += DOMPurify.sanitize(marked.parse(line));
});
}
} else {
// For thinking and reasoning, format as markdown
const lines = content.split("\n");
lines.forEach((line) => {
c += DOMPurify.sanitize(marked.parse(line));
});
}
// Set expanded state: thinking is expanded by default in non-MCP mode, collapsed in MCP mode
// Reasoning, tool_call, and tool_result are always collapsed by default
const isMCPMode = this.mcpMode || false;
const shouldExpand = (role === "thinking" && !isMCPMode) || false;
this.history.push({ role, content, html: c, image, audio, expanded: shouldExpand });
}
// For other messages, merge if same role
else if (this.history.length && this.history[N].role === role) {
@@ -147,7 +169,22 @@ SOFTWARE.
audio: audio || []
});
}
document.getElementById('messages').scrollIntoView(false);
// Scroll to bottom consistently for all messages (use #chat as it's the scrollable container)
setTimeout(() => {
const chatContainer = document.getElementById('chat');
if (chatContainer) {
chatContainer.scrollTo({
top: chatContainer.scrollHeight,
behavior: 'smooth'
});
}
// Also scroll thinking box if it's a thinking/reasoning message
if (role === "thinking" || role === "reasoning") {
if (typeof window.scrollThinkingBoxToBottom === 'function') {
window.scrollThinkingBoxToBottom();
}
}
}, 100);
const parser = new DOMParser();
const html = parser.parseFromString(
this.history[this.history.length - 1].html,
@@ -160,9 +197,33 @@ SOFTWARE.
if (this.languages.includes(language)) return;
const script = document.createElement("script");
script.src = `https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/languages/${language}.min.js`;
script.onload = () => {
// Re-highlight after language script loads
if (window.hljs) {
const container = document.getElementById('messages');
if (container) {
container.querySelectorAll('pre code.language-json').forEach(block => {
window.hljs.highlightElement(block);
});
}
}
};
document.head.appendChild(script);
this.languages.push(language);
});
// Highlight code blocks immediately if hljs is available
if (window.hljs) {
setTimeout(() => {
const container = document.getElementById('messages');
if (container) {
container.querySelectorAll('pre code.language-json').forEach(block => {
if (!block.classList.contains('hljs')) {
window.hljs.highlightElement(block);
}
});
}
}, 100);
}
},
messages() {
return this.history.map((message) => ({
@@ -484,9 +545,113 @@ SOFTWARE.
<li>To send a text, markdown or PDF file, click the <i class="fa-solid fa-file text-[#38BDF8]"></i> icon.</li>
</ul>
</p>
<div id="messages" class="max-w-3xl mx-auto">
<template x-for="message in history">
<div :class="message.role === 'user' ? 'flex items-start space-x-2 my-2 justify-end' : 'flex items-start space-x-2 my-2'">
<div id="messages" class="max-w-3xl mx-auto space-y-2">
<template x-for="(message, index) in history" :key="index">
<div>
<!-- Reasoning/Thinking messages appear first (before assistant) - collapsible in MCP mode -->
<template x-if="message.role === 'reasoning' || message.role === 'thinking'">
<div class="flex items-start space-x-2 mb-1">
<div class="flex flex-col flex-1">
<div class="p-2 flex-1 rounded-lg bg-[#38BDF8]/10 text-[#94A3B8] border border-[#38BDF8]/30">
<button
@click="message.expanded = !message.expanded"
class="w-full flex items-center justify-between text-left hover:bg-[#38BDF8]/20 rounded p-2 transition-colors"
>
<div class="flex items-center space-x-2">
<i :class="message.role === 'thinking' ? 'fa-solid fa-brain' : 'fa-solid fa-lightbulb'" class="text-[#38BDF8]"></i>
<span class="text-xs font-semibold text-[#38BDF8]" x-text="message.role === 'thinking' ? 'Thinking' : 'Reasoning'"></span>
<span class="text-xs text-[#94A3B8]" x-show="message.content && message.content.length > 0" x-text="'(' + Math.ceil(message.content.length / 100) + ' lines)'"></span>
</div>
<i
class="fa-solid text-[#38BDF8] transition-transform text-xs"
:class="message.expanded ? 'fa-chevron-up' : 'fa-chevron-down'"
></i>
</button>
<div
x-show="message.expanded"
x-transition
class="mt-2 pt-2 border-t border-[#38BDF8]/20"
>
<div
class="text-[#E5E7EB] text-sm max-h-96 overflow-auto"
x-html="message.html"
data-thinking-box
x-effect="if (message.expanded && message.html) { setTimeout(() => { if ($el.scrollHeight > $el.clientHeight) { $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' }); } }, 50); }"
></div>
</div>
</div>
</div>
</div>
</template>
<!-- Tool calls (collapsible) -->
<template x-if="message.role === 'tool_call'">
<div class="flex items-start space-x-2 mb-1">
<div class="flex flex-col flex-1">
<div class="p-2 flex-1 rounded-lg bg-[#8B5CF6]/10 text-[#94A3B8] border border-[#8B5CF6]/30">
<button
@click="message.expanded = !message.expanded"
class="w-full flex items-center justify-between text-left hover:bg-[#8B5CF6]/20 rounded p-2 transition-colors"
>
<div class="flex items-center space-x-2">
<i class="fa-solid fa-wrench text-[#8B5CF6]"></i>
<span class="text-xs font-semibold text-[#8B5CF6]">Tool Call</span>
<span class="text-xs text-[#94A3B8]" x-text="getToolName(message.content)"></span>
</div>
<i
class="fa-solid text-[#8B5CF6] transition-transform text-xs"
:class="message.expanded ? 'fa-chevron-up' : 'fa-chevron-down'"
></i>
</button>
<div
x-show="message.expanded"
x-transition
class="mt-2 pt-2 border-t border-[#8B5CF6]/20"
>
<div class="text-[#E5E7EB] text-xs max-h-96 overflow-auto overflow-x-auto tool-call-content"
x-html="message.html"
x-effect="if (message.expanded && window.hljs) { setTimeout(() => { $el.querySelectorAll('pre code.language-json').forEach(block => { if (!block.classList.contains('hljs')) window.hljs.highlightElement(block); }); }, 50); }"></div>
</div>
</div>
</div>
</div>
</template>
<!-- Tool results (collapsible) -->
<template x-if="message.role === 'tool_result'">
<div class="flex items-start space-x-2 mb-1">
<div class="flex flex-col flex-1">
<div class="p-2 flex-1 rounded-lg bg-[#10B981]/10 text-[#94A3B8] border border-[#10B981]/30">
<button
@click="message.expanded = !message.expanded"
class="w-full flex items-center justify-between text-left hover:bg-[#10B981]/20 rounded p-2 transition-colors"
>
<div class="flex items-center space-x-2">
<i class="fa-solid fa-check-circle text-[#10B981]"></i>
<span class="text-xs font-semibold text-[#10B981]">Tool Result</span>
<span class="text-xs text-[#94A3B8]" x-text="getToolName(message.content) || 'Success'"></span>
</div>
<i
class="fa-solid text-[#10B981] transition-transform text-xs"
:class="message.expanded ? 'fa-chevron-up' : 'fa-chevron-down'"
></i>
</button>
<div
x-show="message.expanded"
x-transition
class="mt-2 pt-2 border-t border-[#10B981]/20"
>
<div class="text-[#E5E7EB] text-xs max-h-96 overflow-auto overflow-x-auto tool-result-content"
x-html="formatToolResult(message.content)"
x-effect="if (message.expanded && window.hljs) { setTimeout(() => { $el.querySelectorAll('pre code.language-json').forEach(block => { if (!block.classList.contains('hljs')) window.hljs.highlightElement(block); }); }, 50); }"></div>
</div>
</div>
</div>
</div>
</template>
<!-- User and Assistant messages -->
<div :class="message.role === 'user' ? 'flex items-start space-x-2 justify-end' : 'flex items-start space-x-2'">
{{ if .Model }}
{{ $galleryConfig:= index $allGalleryConfigs .Model}}
<template x-if="message.role === 'user'">
@@ -514,20 +679,7 @@ SOFTWARE.
</div>
</div>
</template>
<template x-if="message.role === 'thinking'">
<div class="flex items-center space-x-2 w-full">
<div class="flex flex-col flex-1">
<div class="p-3 flex-1 rounded-lg bg-[#38BDF8]/10 text-[#94A3B8] border border-[#38BDF8]/30">
<div class="flex items-center space-x-2 mb-2">
<i class="fa-solid fa-brain text-[#38BDF8]"></i>
<span class="text-xs font-semibold text-[#38BDF8]">Thinking</span>
</div>
<div class="mt-1 text-[#E5E7EB]" x-html="message.html"></div>
</div>
</div>
</div>
</template>
<template x-if="message.role != 'user' && message.role != 'thinking'">
<template x-if="message.role != 'user' && message.role != 'thinking' && message.role != 'reasoning' && message.role != 'tool_call' && message.role != 'tool_result'">
<div class="flex items-center space-x-2">
{{ if $galleryConfig }}
{{ if $galleryConfig.Icon }}<img src="{{$galleryConfig.Icon}}" class="rounded-lg mt-2 max-w-8 max-h-8 border border-[#38BDF8]/20">{{end}}
@@ -566,6 +718,7 @@ SOFTWARE.
:class="message.role === 'user' ? 'fa-user text-[#38BDF8]' : 'fa-robot text-[#8B5CF6]'"
></i>
{{ end }}
</div>
</div>
</template>
</div>
@@ -573,8 +726,26 @@ SOFTWARE.
<!-- Chat Input -->
<div class="p-4 border-t border-[#1E293B]" x-data="{ inputValue: '', shiftPressed: false, fileName: '' }">
<div class="p-4 border-t border-[#1E293B]" x-data="{ inputValue: '', shiftPressed: false, attachedFiles: [] }">
<form id="prompt" action="chat/{{.Model}}" method="get" @submit.prevent="submitPrompt" class="max-w-3xl mx-auto">
<!-- Attachment Tags - Show above input when files are attached -->
<div x-show="attachedFiles.length > 0" class="mb-3 flex flex-wrap gap-2 items-center">
<template x-for="(file, index) in attachedFiles" :key="index">
<div class="inline-flex items-center gap-2 px-3 py-1.5 rounded-lg text-sm bg-[#38BDF8]/20 border border-[#38BDF8]/40 text-[#E5E7EB]">
<i :class="file.type === 'image' ? 'fa-solid fa-image' : file.type === 'audio' ? 'fa-solid fa-microphone' : 'fa-solid fa-file'" class="text-[#38BDF8]"></i>
<span x-text="file.name" class="max-w-[200px] truncate"></span>
<button
type="button"
@click="attachedFiles.splice(index, 1); removeFileFromInput(file.type, file.name)"
class="ml-1 text-[#94A3B8] hover:text-[#E5E7EB] transition-colors"
title="Remove attachment"
>
<i class="fa-solid fa-times text-xs"></i>
</button>
</div>
</template>
</div>
<!-- Token Usage and Context Window - Compact above input -->
<div class="mb-3 flex items-center justify-between gap-4 text-xs">
<!-- Token Usage -->
@@ -626,20 +797,19 @@ SOFTWARE.
</template>
</div>
<div class="relative w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl shadow-lg">
<div class="relative w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl shadow-lg focus-within:ring-2 focus-within:ring-[#38BDF8]/50 focus-within:border-[#38BDF8] transition-all duration-200">
<textarea
id="input"
name="input"
x-model="inputValue"
placeholder="Send a message..."
class="p-3 pr-16 w-full bg-[#1E293B] text-[#E5E7EB] placeholder-[#94A3B8] focus:outline-none resize-none border-0 rounded-xl transition-colors duration-200 focus:ring-2 focus:ring-[#38BDF8]/50"
class="p-3 pr-16 w-full bg-[#1E293B] text-[#E5E7EB] placeholder-[#94A3B8] focus:outline-none resize-none border-0 rounded-xl transition-colors duration-200"
required
@keydown.shift="shiftPressed = true"
@keyup.shift="shiftPressed = false"
@keydown.enter.prevent="if (!shiftPressed) { submitPrompt($event); }"
rows="2"
></textarea>
<span x-text="fileName" id="fileName" class="absolute right-16 top-3 text-[#94A3B8] text-xs mr-2"></span>
<button
type="button"
onclick="document.getElementById('input_image').click()"
@@ -692,7 +862,7 @@ SOFTWARE.
multiple
accept="image/*"
style="display: none;"
@change="fileName = $event.target.files.length + ' image(s) selected'"
@change="handleFileSelection($event, 'image')"
/>
<input
id="input_audio"
@@ -700,7 +870,7 @@ SOFTWARE.
multiple
accept="audio/*"
style="display: none;"
@change="fileName = $event.target.files.length + ' audio file(s) selected'"
@change="handleFileSelection($event, 'audio')"
/>
<input
id="input_file"
@@ -708,7 +878,7 @@ SOFTWARE.
multiple
accept=".txt,.md,.pdf"
style="display: none;"
@change="fileName = $event.target.files.length + ' file(s) selected'"
@change="handleFileSelection($event, 'file')"
/>
</div>
</form>
@@ -775,6 +945,83 @@ SOFTWARE.
console.error('Failed to copy: ', err);
});
};
// Format tool result for better display
window.formatToolResult = (content) => {
if (!content) return '';
try {
// Try to parse as JSON
const parsed = JSON.parse(content);
// If it has a 'result' field, try to parse that too
if (parsed.result && typeof parsed.result === 'string') {
try {
const resultParsed = JSON.parse(parsed.result);
parsed.result = resultParsed;
} catch (e) {
// Keep as string if not JSON
}
}
// Format the JSON nicely
const formatted = JSON.stringify(parsed, null, 2);
return DOMPurify.sanitize('<pre class="bg-[#101827] p-3 rounded border border-[#10B981]/20 overflow-x-auto"><code class="language-json">' + formatted + '</code></pre>');
} catch (e) {
// If not JSON, try to format as markdown or plain text
try {
// Check if it's a markdown code block
if (content.includes('```')) {
return DOMPurify.sanitize(marked.parse(content));
}
// Otherwise, try to parse as JSON one more time with error handling
const lines = content.split('\n');
let jsonStart = -1;
let jsonEnd = -1;
for (let i = 0; i < lines.length; i++) {
if (lines[i].trim().startsWith('{') || lines[i].trim().startsWith('[')) {
jsonStart = i;
break;
}
}
if (jsonStart >= 0) {
for (let i = lines.length - 1; i >= jsonStart; i--) {
if (lines[i].trim().endsWith('}') || lines[i].trim().endsWith(']')) {
jsonEnd = i;
break;
}
}
if (jsonEnd >= jsonStart) {
const jsonStr = lines.slice(jsonStart, jsonEnd + 1).join('\n');
try {
const parsed = JSON.parse(jsonStr);
const formatted = JSON.stringify(parsed, null, 2);
return DOMPurify.sanitize('<pre class="bg-[#101827] p-3 rounded border border-[#10B981]/20 overflow-x-auto"><code class="language-json">' + formatted + '</code></pre>');
} catch (e2) {
// Fall through to markdown
}
}
}
// Fall back to markdown
return DOMPurify.sanitize(marked.parse(content));
} catch (e2) {
// Last resort: plain text
return DOMPurify.sanitize('<pre class="bg-[#101827] p-3 rounded border border-[#10B981]/20 overflow-x-auto text-xs">' + content.replace(/</g, '&lt;').replace(/>/g, '&gt;') + '</pre>');
}
}
};
// Get tool name from content
window.getToolName = (content) => {
if (!content || typeof content !== 'string') return '';
try {
const parsed = JSON.parse(content);
return parsed.name || '';
} catch (e) {
// Try to extract name from string
const nameMatch = content.match(/"name"\s*:\s*"([^"]+)"/);
return nameMatch ? nameMatch[1] : '';
}
};
});
// Context size is now initialized in the Alpine store initialization above
@@ -904,6 +1151,76 @@ SOFTWARE.
max-width: 100%;
height: auto;
}
/* Prevent JSON overflow in tool calls and results */
.tool-call-content pre,
.tool-result-content pre {
overflow-x: auto;
overflow-y: auto;
max-width: 100%;
word-wrap: break-word;
white-space: pre-wrap;
background: #101827 !important;
border: 1px solid #1E293B;
border-radius: 6px;
padding: 12px;
margin: 0;
}
.tool-call-content code,
.tool-result-content code {
word-wrap: break-word;
white-space: pre-wrap;
overflow-wrap: break-word;
background: transparent !important;
color: #E5E7EB;
font-family: 'ui-monospace', 'Monaco', 'Consolas', monospace;
font-size: 0.875rem;
line-height: 1.5;
}
/* Dark theme syntax highlighting for JSON */
.tool-call-content .hljs,
.tool-result-content .hljs {
background: #101827 !important;
color: #E5E7EB !important;
}
.tool-call-content .hljs-keyword,
.tool-result-content .hljs-keyword {
color: #8B5CF6 !important;
font-weight: 600;
}
.tool-call-content .hljs-string,
.tool-result-content .hljs-string {
color: #10B981 !important;
}
.tool-call-content .hljs-number,
.tool-result-content .hljs-number {
color: #38BDF8 !important;
}
.tool-call-content .hljs-literal,
.tool-result-content .hljs-literal {
color: #F59E0B !important;
}
.tool-call-content .hljs-punctuation,
.tool-result-content .hljs-punctuation {
color: #94A3B8 !important;
}
.tool-call-content .hljs-property,
.tool-result-content .hljs-property {
color: #38BDF8 !important;
}
.tool-call-content .hljs-attr,
.tool-result-content .hljs-attr {
color: #8B5CF6 !important;
}
</style>
</body>
</html>

View File

@@ -9,9 +9,9 @@
<div class="container mx-auto px-4 py-8 flex-grow">
<!-- Error Section -->
<div class="bg-[#1E293B] border border-red-500/20 rounded-2xl shadow-2xl shadow-red-500/10 p-8 mb-10">
<div class="bg-[#1E293B] border border-red-500/20 rounded-xl p-8 mb-10">
<div class="max-w-4xl mx-auto text-center">
<div class="mb-6 text-6xl text-red-400 animate-pulse">
<div class="mb-6 text-6xl text-red-400">
<i class="fas fa-exclamation-circle"></i>
</div>
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
@@ -22,23 +22,21 @@
<p class="text-xl text-[#94A3B8] mb-6">{{if .ErrorMessage}}{{.ErrorMessage}}{{else}}An unexpected error occurred{{end}}</p>
<div class="flex flex-wrap justify-center gap-4">
<a href="./"
class="group flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(56,189,248,0.4)]">
class="inline-flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition-colors">
<i class="fas fa-home mr-2"></i>
<span>Return Home</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</a>
<a href="browse"
class="group flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(139,92,246,0.4)]">
class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition-colors">
<i class="fas fa-images mr-2"></i>
<span>Browse Gallery</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</a>
</div>
</div>
</div>
<!-- Additional Information -->
<div class="bg-[#1E293B]/80 border border-[#1E293B] rounded-xl p-8 shadow-lg backdrop-blur-sm">
<div class="bg-[#1E293B] border border-[#1E293B] rounded-xl p-8">
<div class="text-center max-w-3xl mx-auto">
<div class="inline-flex items-center justify-center w-16 h-16 rounded-full bg-yellow-500/10 border border-yellow-500/20 mb-4">
<i class="text-yellow-400 text-2xl fa-solid fa-triangle-exclamation"></i>

View File

@@ -23,12 +23,10 @@
padding: 20px;
border-radius: 8px;
margin-bottom: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
transition: transform 0.3s ease, box-shadow 0.3s ease;
transition: background-color 0.2s ease;
}
.network-card:hover {
transform: translateY(-5px);
box-shadow: 0 6px 10px rgba(0, 0, 0, 0.15);
background-color: #374151;
}
.network-title {
font-size: 24px;

View File

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@
<div class="container mx-auto px-4 py-8 flex-grow flex items-center justify-center">
<!-- Auth Card -->
<div class="max-w-md w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl overflow-hidden shadow-2xl shadow-[#38BDF8]/10">
<div class="max-w-md w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl overflow-hidden">
<div class="animation-container">
<div class="text-overlay">
<img src="static/logo.png" alt="LocalAI Logo" class="h-32 drop-shadow-[0_0_15px_rgba(56,189,248,0.3)]">
@@ -47,11 +47,10 @@
<div>
<button
type="submit"
class="group w-full flex items-center justify-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-[1.02] hover:shadow-[0_0_20px_rgba(56,189,248,0.4)]"
class="w-full flex items-center justify-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition-colors"
>
<i class="fas fa-sign-in-alt mr-2"></i>
<span>Login</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</button>
</div>
</form>

569
core/http/views/manage.html Normal file
View File

@@ -0,0 +1,569 @@
<!DOCTYPE html>
<html lang="en">
{{template "views/partials/head" .}}
<body class="bg-[#101827] text-[#E5E7EB]">
<div class="flex flex-col min-h-screen" x-data="indexDashboard()">
{{template "views/partials/navbar" .}}
<!-- Notifications -->
<div class="fixed top-20 right-4 z-50 space-y-2" style="max-width: 400px;">
<template x-for="notification in notifications" :key="notification.id">
<div x-show="true"
x-transition:enter="transition ease-out duration-200"
x-transition:enter-start="opacity-0"
x-transition:enter-end="opacity-100"
x-transition:leave="transition ease-in duration-150"
x-transition:leave-start="opacity-100"
x-transition:leave-end="opacity-0"
:class="notification.type === 'error' ? 'bg-red-500' : 'bg-green-500'"
class="rounded-lg p-4 text-white flex items-start space-x-3">
<div class="flex-shrink-0">
<i :class="notification.type === 'error' ? 'fas fa-exclamation-circle' : 'fas fa-check-circle'" class="text-xl"></i>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium break-words" x-text="notification.message"></p>
</div>
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:opacity-80 transition-opacity">
<i class="fas fa-times"></i>
</button>
</div>
</template>
</div>
<div class="container mx-auto px-4 py-6 flex-grow">
<!-- Header -->
<div class="mb-6">
<h1 class="text-2xl font-semibold text-[#E5E7EB] mb-1">
Model & Backend Management
</h1>
<p class="text-sm text-[#94A3B8]">Manage your installed models and backends</p>
</div>
<!-- Quick Actions -->
<div class="flex flex-wrap gap-2 mb-6">
<a href="browse"
class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-images mr-1.5 text-[10px]"></i>
<span>Model Gallery</span>
</a>
<a href="/import-model"
class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-plus mr-1.5 text-[10px]"></i>
<span>Import Model</span>
</a>
<button id="reload-models-btn"
class="inline-flex items-center bg-orange-600 hover:bg-orange-700 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-sync-alt mr-1.5 text-[10px]"></i>
<span>Update Models</span>
</button>
<a href="/browse/backends"
class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#8B5CF6]/20 text-[#E5E7EB] py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-cogs mr-1.5 text-[10px]"></i>
<span>Backend Gallery</span>
</a>
</div>
<!-- Models Section -->
<div class="models mt-8">
{{template "views/partials/inprogress" .}}
{{ if eq (len .ModelsConfig) 0 }}
<!-- No Models State -->
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-lg p-8">
<div class="text-center max-w-4xl mx-auto">
<div class="inline-flex items-center justify-center w-12 h-12 rounded-full bg-yellow-500/10 border border-yellow-500/20 mb-4">
<i class="text-yellow-400 text-xl fas fa-robot"></i>
</div>
<h2 class="text-2xl font-bold text-[#E5E7EB] mb-2">No models installed yet</h2>
<p class="text-sm text-[#94A3B8] mb-6">Get started by installing a model from the gallery or importing it</p>
<div class="flex flex-wrap justify-center gap-2 mb-6">
<a href="browse" class="inline-flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-images mr-1.5 text-[10px]"></i>
Browse Model Gallery
</a>
<a href="/import-model" class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-upload mr-1.5 text-[10px]"></i>
Import Model
</a>
<a href="https://localai.io/basics/getting_started/" target="_blank" class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#38BDF8]/20 text-[#E5E7EB] py-1.5 px-3 rounded text-xs font-medium transition-colors">
<i class="fas fa-book mr-1.5 text-[10px]"></i>
Documentation
</a>
</div>
{{ if ne (len .Models) 0 }}
<div class="mt-8 pt-6 border-t border-[#38BDF8]/20">
<h3 class="text-lg font-semibold text-[#E5E7EB] mb-2 flex items-center">
<i class="fas fa-file-alt mr-2 text-[#38BDF8] text-sm"></i>
Detected Model Files
</h3>
<p class="text-xs text-[#94A3B8] mb-4">These models were found but don't have configuration files yet</p>
<div class="flex flex-wrap gap-2 justify-center">
{{ range .Models }}
<div class="bg-[#101827] border border-[#38BDF8]/20 rounded px-2 py-1 flex items-center gap-2">
<i class="fas fa-brain text-xs text-[#38BDF8]"></i>
<span class="text-xs text-[#E5E7EB] font-medium">{{.}}</span>
</div>
{{end}}
</div>
</div>
{{end}}
</div>
</div>
{{ else }}
<!-- Models Table -->
{{ $modelsN := len .ModelsConfig}}
{{ $modelsN = add $modelsN (len .Models)}}
<div class="mb-6">
<h2 class="text-2xl font-semibold text-[#E5E7EB] mb-1 flex items-center">
<i class="fas fa-brain mr-2 text-[#38BDF8] text-sm"></i>
Installed Models
</h2>
<p class="text-sm text-[#94A3B8] mb-4">
<span class="text-[#38BDF8] font-medium">{{$modelsN}}</span> model{{if gt $modelsN 1}}s{{end}} ready to use
</p>
</div>
<div class="overflow-x-auto mb-8">
<table class="w-full border-collapse">
<thead>
<tr class="border-b border-[#1E293B]">
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Name</th>
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Status</th>
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Backend</th>
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Use Cases</th>
<th class="text-right p-2 text-xs font-semibold text-[#94A3B8]">Actions</th>
</tr>
</thead>
<tbody>
{{$galleryConfig:=.GalleryConfig}}
{{ $loadedModels := .LoadedModels }}
{{ range .ModelsConfig }}
{{ $backendCfg := . }}
{{ $cfg:= index $galleryConfig .Name}}
<tr class="hover:bg-[#1E293B]/50 border-b border-[#1E293B] transition-colors">
<!-- Name Column -->
<td class="p-2">
<div class="flex items-center gap-2">
<div class="relative flex-shrink-0">
{{ if and $cfg $cfg.Icon }}
<img src="{{$cfg.Icon}}" class="w-4 h-4 object-contain" alt="{{.Name}} icon">
{{ else }}
<i class="fas fa-brain text-xs text-[#38BDF8]"></i>
{{ end }}
{{ if index $loadedModels .Name }}
<div class="absolute -top-0.5 -right-0.5 w-2 h-2 bg-green-500 rounded-full border border-[#1E293B]"></div>
{{ end }}
</div>
<span class="text-xs text-[#E5E7EB] font-medium truncate">{{.Name}}</span>
<a href="/models/edit/{{.Name}}"
class="text-[#38BDF8]/60 hover:text-[#38BDF8] hover:bg-[#38BDF8]/10 rounded p-0.5 transition-colors ml-1 flex-shrink-0"
title="Edit {{.Name}}">
<i class="fas fa-edit text-[10px]"></i>
</a>
</div>
</td>
<!-- Status Column -->
<td class="p-2">
<div class="flex flex-wrap gap-1">
{{ if index $loadedModels .Name }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-green-500/10 text-green-300">
<i class="fas fa-circle text-[8px] mr-1"></i>Running
</span>
{{ end }}
{{ if and $backendCfg (or (ne $backendCfg.MCP.Servers "") (ne $backendCfg.MCP.Stdio "")) }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#8B5CF6]/10 text-[#8B5CF6]">
<i class="fas fa-plug text-[8px] mr-1"></i>MCP
</span>
{{ end }}
</div>
</td>
<!-- Backend Column -->
<td class="p-2">
{{ if .Backend }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#38BDF8]/10 text-[#38BDF8]">
<i class="fas fa-cog text-[8px] mr-1"></i>{{.Backend}}
</span>
{{ else }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-yellow-500/10 text-yellow-300">
<i class="fas fa-magic text-[8px] mr-1"></i>Auto
</span>
{{ end }}
</td>
<!-- Use Cases Column -->
<td class="p-2">
<div class="flex flex-wrap gap-1">
{{ range .KnownUsecaseStrings }}
{{ if eq . "FLAG_CHAT" }}
<a href="chat/{{$backendCfg.Name}}" class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#38BDF8]/10 text-[#38BDF8] hover:bg-[#38BDF8]/20 transition-colors" title="Chat">
<i class="fas fa-comment-alt text-[8px] mr-1"></i>Chat
</a>
{{ end }}
{{ if eq . "FLAG_IMAGE" }}
<a href="text2image/{{$backendCfg.Name}}" class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-green-500/10 text-green-300 hover:bg-green-500/20 transition-colors" title="Image">
<i class="fas fa-image text-[8px] mr-1"></i>Image
</a>
{{ end }}
{{ if eq . "FLAG_TTS" }}
<a href="tts/{{$backendCfg.Name}}" class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#8B5CF6]/10 text-[#8B5CF6] hover:bg-[#8B5CF6]/20 transition-colors" title="TTS">
<i class="fas fa-microphone text-[8px] mr-1"></i>TTS
</a>
{{ end }}
{{ end }}
</div>
</td>
<!-- Actions Column -->
<td class="p-2">
<div class="flex items-center justify-end gap-1">
{{ if index $loadedModels .Name }}
<button class="text-red-400/60 hover:text-red-400 hover:bg-red-500/10 rounded p-1 transition-colors"
onclick="handleStopModel('{{.Name}}')"
title="Stop {{.Name}}">
<i class="fas fa-stop text-xs"></i>
</button>
{{ end }}
<button class="text-red-400/60 hover:text-red-400 hover:bg-red-500/10 rounded p-1 transition-colors"
onclick="handleDeleteModel('{{.Name}}')"
title="Delete {{.Name}}">
<i class="fas fa-trash-alt text-xs"></i>
</button>
</div>
</td>
</tr>
{{ end }}
<!-- Models without config -->
{{ range .Models }}
<tr class="hover:bg-[#1E293B]/50 border-b border-[#1E293B] transition-colors">
<td class="p-2">
<div class="flex items-center gap-2">
<i class="fas fa-brain text-xs text-[#94A3B8]"></i>
<span class="text-xs text-[#E5E7EB] font-medium truncate">{{.}}</span>
</div>
</td>
<td class="p-2">
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-orange-500/10 text-orange-300">
<i class="fas fa-exclamation-triangle text-[8px] mr-1"></i>No Config
</span>
</td>
<td class="p-2">
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-yellow-500/10 text-yellow-300">
<i class="fas fa-magic text-[8px] mr-1"></i>Auto
</span>
</td>
<td class="p-2">
<span class="text-xs text-[#94A3B8]"></span>
</td>
<td class="p-2">
<span class="text-xs text-[#94A3B8]"></span>
</td>
</tr>
{{end}}
</tbody>
</table>
</div>
{{ end }}
</div>
<!-- Backends Section -->
<div class="mt-8">
<div class="mb-6">
<h2 class="text-2xl font-semibold text-[#E5E7EB] mb-1 flex items-center">
<i class="fas fa-cogs mr-2 text-[#8B5CF6] text-sm"></i>
Installed Backends
</h2>
<p class="text-sm text-[#94A3B8] mb-4">
<span class="text-[#8B5CF6] font-medium">{{len .InstalledBackends}}</span> backend{{if gt (len .InstalledBackends) 1}}s{{end}} ready to use
</p>
</div>
{{ if eq (len .InstalledBackends) 0 }}
<!-- No backends state -->
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-lg p-8">
<div class="text-center max-w-4xl mx-auto">
<div class="inline-flex items-center justify-center w-12 h-12 rounded-full bg-[#8B5CF6]/10 border border-[#8B5CF6]/20 mb-4">
<i class="text-[#8B5CF6] text-xl fas fa-cogs"></i>
</div>
<h2 class="text-2xl font-bold text-[#E5E7EB] mb-2">No backends installed yet</h2>
<p class="text-sm text-[#94A3B8] mb-6">Backends power your AI models. Install them from the backend gallery to get started</p>
<div class="flex flex-wrap justify-center gap-3">
<a href="/browse/backends" class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white py-2 px-4 rounded-lg text-sm font-medium transition-colors">
<i class="fas fa-cogs mr-2 text-xs"></i>
Browse Backend Gallery
</a>
<a href="https://localai.io/backends/" target="_blank" class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#8B5CF6]/20 text-[#E5E7EB] py-2 px-4 rounded-lg text-sm font-medium transition-colors">
<i class="fas fa-book mr-2 text-xs"></i>
Documentation
</a>
</div>
</div>
</div>
{{ else }}
<!-- Backends Table -->
<div class="overflow-x-auto mb-8">
<table class="w-full border-collapse">
<thead>
<tr class="border-b border-[#1E293B]">
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Name</th>
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Type</th>
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Metadata</th>
<th class="text-right p-2 text-xs font-semibold text-[#94A3B8]">Actions</th>
</tr>
</thead>
<tbody>
{{ range .InstalledBackends }}
<tr class="hover:bg-[#1E293B]/50 border-b border-[#1E293B] transition-colors">
<!-- Name Column -->
<td class="p-2">
<div class="flex items-center gap-2">
<i class="fas fa-cog text-xs text-[#8B5CF6]"></i>
<span class="text-xs text-[#E5E7EB] font-medium truncate">{{.Name}}</span>
</div>
</td>
<!-- Type Column -->
<td class="p-2">
<div class="flex flex-wrap gap-1">
{{ if .IsSystem }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-blue-500/10 text-blue-300">
<i class="fas fa-shield-alt text-[8px] mr-1"></i>System
</span>
{{ else }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-green-500/10 text-green-300">
<i class="fas fa-download text-[8px] mr-1"></i>User
</span>
{{ end }}
{{ if .IsMeta }}
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#8B5CF6]/10 text-[#8B5CF6]">
<i class="fas fa-layer-group text-[8px] mr-1"></i>Meta
</span>
{{ end }}
</div>
</td>
<!-- Metadata Column -->
<td class="p-2">
<div class="flex flex-col gap-1">
{{ if and .Metadata .Metadata.Alias }}
<span class="text-xs text-[#94A3B8]">
<i class="fas fa-tag text-[8px] mr-1"></i>Alias: <span class="text-[#E5E7EB]">{{.Metadata.Alias}}</span>
</span>
{{ end }}
{{ if and .Metadata .Metadata.MetaBackendFor }}
<span class="text-xs text-[#94A3B8]">
<i class="fas fa-link text-[8px] mr-1"></i>For: <span class="text-[#8B5CF6]">{{.Metadata.MetaBackendFor}}</span>
</span>
{{ end }}
{{ if and .Metadata .Metadata.InstalledAt }}
<span class="text-xs text-[#94A3B8]">
<i class="fas fa-calendar text-[8px] mr-1"></i>{{.Metadata.InstalledAt}}
</span>
{{ end }}
</div>
</td>
<!-- Actions Column -->
<td class="p-2">
<div class="flex items-center justify-end gap-1">
{{ if not .IsSystem }}
<button
@click="deleteBackend('{{.Name}}')"
class="text-red-400/60 hover:text-red-400 hover:bg-red-500/10 rounded p-1 transition-colors"
title="Delete {{.Name}}">
<i class="fas fa-trash-alt text-xs"></i>
</button>
{{ else }}
<span class="text-xs text-[#94A3B8]"></span>
{{ end }}
</div>
</td>
</tr>
{{end}}
</tbody>
</table>
</div>
{{ end }}
</div>
</div>
{{template "views/partials/footer" .}}
</div>
<script>
// Alpine.js component for index dashboard
function indexDashboard() {
return {
notifications: [],
init() {
// Initialize component
},
addNotification(message, type = 'success') {
const id = Date.now();
this.notifications.push({ id, message, type });
// Auto-dismiss after 5 seconds
setTimeout(() => this.dismissNotification(id), 5000);
},
dismissNotification(id) {
this.notifications = this.notifications.filter(n => n.id !== id);
},
async deleteBackend(backendName) {
if (!confirm(`Are you sure you want to delete the backend "${backendName}"?`)) {
return;
}
try {
const response = await fetch(`/api/backends/system/delete/${encodeURIComponent(backendName)}`, {
method: 'POST'
});
const data = await response.json();
if (response.ok && data.success) {
this.addNotification(`Backend "${backendName}" deleted successfully!`, 'success');
// Reload page after short delay
setTimeout(() => {
window.location.reload();
}, 1500);
} else {
this.addNotification(`Failed to delete backend: ${data.error || 'Unknown error'}`, 'error');
}
} catch (error) {
console.error('Error deleting backend:', error);
this.addNotification(`Failed to delete backend: ${error.message}`, 'error');
}
}
}
}
async function handleStopModel(modelName) {
if (!confirm('Are you sure you wish to stop this model?')) {
return;
}
try {
const response = await fetch('/backend/shutdown', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ model: modelName })
});
if (response.ok) {
window.location.reload();
} else {
alert('Failed to stop model');
}
} catch (error) {
console.error('Error stopping model:', error);
alert('Failed to stop model');
}
}
async function handleDeleteModel(modelName) {
if (!confirm('Are you sure you wish to delete this model?')) {
return;
}
try {
const response = await fetch(`/api/models/delete/${encodeURIComponent(modelName)}`, {
method: 'POST'
});
if (response.ok) {
window.location.reload();
} else {
alert('Failed to delete model');
}
} catch (error) {
console.error('Error deleting model:', error);
alert('Failed to delete model');
}
}
// Handle reload models button
document.addEventListener('DOMContentLoaded', function() {
const reloadBtn = document.getElementById('reload-models-btn');
if (reloadBtn) {
reloadBtn.addEventListener('click', function() {
const button = this;
const originalText = button.querySelector('span').textContent;
const icon = button.querySelector('i');
// Show loading state
button.disabled = true;
button.querySelector('span').textContent = 'Updating...';
icon.classList.add('fa-spin');
// Make the API call
fetch('/models/reload', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.success) {
// Show success state briefly
button.querySelector('span').textContent = 'Updated!';
icon.classList.remove('fa-spin', 'fa-sync-alt');
icon.classList.add('fa-check');
// Reload the page after a short delay
setTimeout(() => {
window.location.reload();
}, 1000);
} else {
// Show error state
button.querySelector('span').textContent = 'Error!';
icon.classList.remove('fa-spin');
console.error('Failed to reload models:', data.error);
// Reset button after delay
setTimeout(() => {
button.disabled = false;
button.querySelector('span').textContent = originalText;
icon.classList.remove('fa-check');
icon.classList.add('fa-sync-alt');
}, 3000);
}
})
.catch(error => {
// Show error state
button.querySelector('span').textContent = 'Error!';
icon.classList.remove('fa-spin');
console.error('Error reloading models:', error);
// Reset button after delay
setTimeout(() => {
button.disabled = false;
button.querySelector('span').textContent = originalText;
icon.classList.remove('fa-check');
icon.classList.add('fa-sync-alt');
}, 3000);
});
});
}
});
</script>
</body>
</html>

View File

@@ -10,45 +10,36 @@
<div class="container mx-auto px-4 py-8 flex-grow">
<!-- Hero Header -->
<div class="relative bg-[#1E293B] border border-[#8B5CF6]/20 rounded-3xl shadow-2xl shadow-[#8B5CF6]/10 p-8 mb-8 overflow-hidden">
<!-- Background Pattern -->
<div class="absolute inset-0 opacity-10">
<div class="absolute inset-0 bg-gradient-to-r from-[#8B5CF6]/20 to-[#38BDF8]/20"></div>
<div class="absolute top-0 left-0 w-full h-full" style="background-image: radial-gradient(circle at 1px 1px, rgba(139,92,246,0.15) 1px, transparent 0); background-size: 20px 20px;"></div>
</div>
<div class="relative max-w-5xl mx-auto">
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8 mb-8">
<div class="max-w-5xl mx-auto">
<div class="flex flex-col md:flex-row md:items-center md:justify-between">
<div class="mb-4 md:mb-0">
<h1 class="text-3xl md:text-4xl font-bold text-white mb-2">
<h1 class="text-3xl md:text-4xl font-bold text-[#E5E7EB] mb-2">
<span class="bg-clip-text text-transparent bg-gradient-to-r from-violet-400 via-purple-400 to-fuchsia-400">
{{if .ModelName}}Edit Model: {{.ModelName}}{{else}}Import New Model{{end}}
</span>
</h1>
<p class="text-lg text-gray-300 font-light" x-text="isAdvancedMode ? 'Configure your model settings using YAML' : 'Import a model from URI with preferences'"></p>
<p class="text-lg text-[#94A3B8] font-light" x-text="isAdvancedMode ? 'Configure your model settings using YAML' : 'Import a model from URI with preferences'"></p>
</div>
<div class="flex gap-3">
<!-- Mode Toggle (only show when not in edit mode) -->
<template x-if="!isEditMode">
<button @click="toggleMode()"
class="group relative inline-flex items-center bg-gradient-to-r from-gray-600 to-gray-700 hover:from-gray-700 hover:to-gray-800 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl">
<i class="fas group-hover:animate-pulse" :class="isAdvancedMode ? 'fa-magic mr-2' : 'fa-code mr-2'"></i>
class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#8B5CF6]/20 text-[#E5E7EB] py-3 px-6 rounded-lg font-semibold transition-colors">
<i class="fas" :class="isAdvancedMode ? 'fa-magic mr-2' : 'fa-code mr-2'"></i>
<span x-text="isAdvancedMode ? 'Simple Mode' : 'Advanced Mode'"></span>
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
</button>
</template>
<!-- Advanced Mode Buttons -->
<template x-if="isAdvancedMode">
<div class="flex gap-3">
<button id="validateBtn" class="group relative inline-flex items-center bg-gradient-to-r from-blue-600 to-blue-700 hover:from-blue-700 hover:to-blue-800 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl hover:shadow-blue-500/25">
<i class="fas fa-check mr-2 group-hover:animate-pulse"></i>
<button id="validateBtn" class="inline-flex items-center bg-blue-600 hover:bg-blue-700 text-white py-3 px-6 rounded-lg font-semibold transition-colors">
<i class="fas fa-check mr-2"></i>
<span>Validate</span>
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
</button>
<button id="saveBtn" class="group relative inline-flex items-center bg-gradient-to-r from-green-600 to-emerald-600 hover:from-green-700 hover:to-emerald-700 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl hover:shadow-green-500/25">
<i class="fas fa-save mr-2 group-hover:animate-pulse"></i>
<button id="saveBtn" class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-3 px-6 rounded-lg font-semibold transition-colors">
<i class="fas fa-save mr-2"></i>
<span>{{if .ModelName}}Update{{else}}Create{{end}}</span>
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
</button>
</div>
</template>
@@ -57,10 +48,9 @@
<button @click="submitImport()"
:disabled="isSubmitting || !importUri.trim()"
:class="(isSubmitting || !importUri.trim()) ? 'opacity-50 cursor-not-allowed' : ''"
class="group relative inline-flex items-center bg-gradient-to-r from-green-600 to-emerald-600 hover:from-green-700 hover:to-emerald-700 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl hover:shadow-green-500/25">
<i class="fas group-hover:animate-pulse" :class="isSubmitting ? 'fa-spinner fa-spin mr-2' : 'fa-upload mr-2'"></i>
class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-3 px-6 rounded-lg font-semibold transition-colors">
<i class="fas" :class="isSubmitting ? 'fa-spinner fa-spin mr-2' : 'fa-upload mr-2'"></i>
<span x-text="isSubmitting ? 'Importing...' : 'Import Model'"></span>
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
</button>
</template>
</div>
@@ -73,15 +63,13 @@
<!-- Simple Import Mode -->
<div x-show="!isAdvancedMode && !isEditMode"
x-transition:enter="transition ease-out duration-300"
x-transition:enter-start="opacity-0 transform translate-y-4"
x-transition:enter-end="opacity-100 transform translate-y-0"
class="relative bg-gradient-to-br from-gray-800/90 to-gray-900/90 border border-gray-700/50 rounded-2xl overflow-hidden shadow-xl backdrop-blur-sm p-8">
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-green-500/5 to-emerald-500/5"></div>
<div class="relative space-y-6">
<h2 class="text-2xl font-semibold text-white flex items-center gap-3 mb-6">
<div class="w-10 h-10 rounded-lg bg-green-500/20 flex items-center justify-center">
x-transition:enter="transition ease-out duration-200"
x-transition:enter-start="opacity-0"
x-transition:enter-end="opacity-100"
class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8">
<div class="space-y-6">
<h2 class="text-2xl font-semibold text-[#E5E7EB] flex items-center gap-3 mb-6">
<div class="w-10 h-10 rounded-lg bg-green-500/10 flex items-center justify-center">
<i class="fas fa-link text-green-400"></i>
</div>
Import from URI
@@ -89,16 +77,16 @@
<!-- URI Input -->
<div>
<label class="block text-sm font-medium text-gray-300 mb-2">
<label class="block text-sm font-medium text-[#94A3B8] mb-2">
<i class="fas fa-link mr-2"></i>Model URI
</label>
<input
x-model="importUri"
type="text"
placeholder="https://example.com/model.gguf or file:///path/to/model.gguf"
class="w-full px-4 py-3 bg-gray-900/90 border border-gray-700/70 rounded-xl text-gray-200 focus:border-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-all"
class="w-full px-4 py-3 bg-[#101827] border border-[#1E293B] rounded-lg text-[#E5E7EB] focus:border-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-colors"
:disabled="isSubmitting">
<p class="mt-2 text-xs text-gray-400">
<p class="mt-2 text-xs text-[#94A3B8]">
Enter the URI or path to the model file you want to import
</p>
</div>
@@ -130,6 +118,8 @@
<option value="llama-cpp">llama-cpp</option>
<option value="mlx">mlx</option>
<option value="mlx-vlm">mlx-vlm</option>
<option value="transformers">transformers</option>
<option value="vllm">vllm</option>
</select>
<p class="mt-1 text-xs text-gray-400">
Force a specific backend. Leave empty to auto-detect from URI.
@@ -199,6 +189,39 @@
Preferred MMProj quantizations (comma-separated). Examples: fp16, fp32. Leave empty to use default (fp16).
</p>
</div>
<!-- Embeddings -->
<div>
<label class="flex items-center cursor-pointer">
<input
x-model="commonPreferences.embeddings"
type="checkbox"
class="w-5 h-5 rounded bg-gray-900/90 border-gray-700/70 text-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-all cursor-pointer"
:disabled="isSubmitting">
<span class="ml-3 text-sm font-medium text-gray-300">
<i class="fas fa-vector-square mr-2"></i>Embeddings
</span>
</label>
<p class="mt-1 ml-8 text-xs text-gray-400">
Enable embeddings support for this model.
</p>
</div>
<!-- Model Type -->
<div>
<label class="block text-sm font-medium text-gray-300 mb-2">
<i class="fas fa-tag mr-2"></i>Model Type
</label>
<input
x-model="commonPreferences.type"
type="text"
placeholder="AutoModelForCausalLM (for transformers backend)"
class="w-full px-4 py-2 bg-gray-900/90 border border-gray-700/70 rounded-lg text-gray-200 focus:border-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-all"
:disabled="isSubmitting">
<p class="mt-1 text-xs text-gray-400">
Model type for transformers backend. Examples: AutoModelForCausalLM, SentenceTransformer, Mamba, MusicgenForConditionalGeneration. Leave empty to use default (AutoModelForCausalLM).
</p>
</div>
</div>
<!-- Custom Preferences -->
@@ -248,25 +271,23 @@
<!-- Advanced YAML Editor Panel -->
<div x-show="isAdvancedMode || isEditMode"
x-transition:enter="transition ease-out duration-300"
x-transition:enter-start="opacity-0 transform translate-y-4"
x-transition:enter-end="opacity-100 transform translate-y-0"
class="relative bg-gradient-to-br from-gray-800/90 to-gray-900/90 border border-gray-700/50 rounded-2xl overflow-hidden shadow-xl backdrop-blur-sm h-[calc(100vh-250px)]">
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-fuchsia-500/5 to-purple-500/5"></div>
<div class="relative sticky top-0 bg-gray-800/95 border-b border-gray-700/50 p-6 flex items-center justify-between z-10 backdrop-blur-sm">
<h2 class="text-xl font-semibold text-white flex items-center gap-3">
<div class="w-8 h-8 rounded-lg bg-fuchsia-500/20 flex items-center justify-center">
x-transition:enter="transition ease-out duration-200"
x-transition:enter-start="opacity-0"
x-transition:enter-end="opacity-100"
class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl overflow-hidden h-[calc(100vh-250px)]">
<div class="sticky top-0 bg-[#1E293B] border-b border-[#101827] p-6 flex items-center justify-between z-10">
<h2 class="text-xl font-semibold text-[#E5E7EB] flex items-center gap-3">
<div class="w-8 h-8 rounded-lg bg-fuchsia-500/10 flex items-center justify-center">
<i class="fas fa-code text-fuchsia-400"></i>
</div>
YAML Configuration Editor
</h2>
<div class="flex items-center gap-3">
<button id="formatYamlBtn" class="group text-gray-400 hover:text-gray-200 text-sm px-3 py-1.5 rounded-lg hover:bg-gray-700/50 transition-all duration-200">
<i class="fas fa-indent mr-1.5 group-hover:animate-pulse"></i> Format
<button id="formatYamlBtn" class="text-[#94A3B8] hover:text-[#E5E7EB] text-sm px-3 py-1.5 rounded-lg hover:bg-[#101827] transition-colors">
<i class="fas fa-indent mr-1.5"></i> Format
</button>
<button id="copyYamlBtn" class="group text-gray-400 hover:text-gray-200 text-sm px-3 py-1.5 rounded-lg hover:bg-gray-700/50 transition-all duration-200">
<i class="fas fa-copy mr-1.5 group-hover:animate-bounce"></i> Copy
<button id="copyYamlBtn" class="text-[#94A3B8] hover:text-[#E5E7EB] text-sm px-3 py-1.5 rounded-lg hover:bg-[#101827] transition-colors">
<i class="fas fa-copy mr-1.5"></i> Copy
</button>
</div>
</div>
@@ -280,13 +301,13 @@
</div>
<!-- Include JS-YAML library -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/js-yaml/4.1.0/js-yaml.min.js"></script>
<script src="static/assets/js-yaml.min.js"></script>
<!-- Include CodeMirror for syntax highlighting -->
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/codemirror.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/codemirror.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/mode/yaml/yaml.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/addon/display/autorefresh.min.js"></script>
<link rel="stylesheet" href="static/assets/codemirror.min.css">
<script src="static/assets/codemirror.min.js"></script>
<script src="static/assets/yaml.min.js"></script>
<script src="static/assets/autorefresh.min.js"></script>
<style>
/* Enhanced CodeMirror styling */
@@ -412,11 +433,9 @@
@keyframes slideInFromTop {
from {
transform: translateY(-20px);
opacity: 0;
}
to {
transform: translateY(0);
opacity: 1;
}
}
@@ -458,7 +477,9 @@ function importModel() {
name: '',
description: '',
quantizations: '',
mmproj_quantizations: ''
mmproj_quantizations: '',
embeddings: false,
type: ''
},
isSubmitting: false,
currentJobId: null,
@@ -527,6 +548,12 @@ function importModel() {
if (this.commonPreferences.mmproj_quantizations && this.commonPreferences.mmproj_quantizations.trim()) {
prefsObj.mmproj_quantizations = this.commonPreferences.mmproj_quantizations.trim();
}
if (this.commonPreferences.embeddings) {
prefsObj.embeddings = 'true';
}
if (this.commonPreferences.type && this.commonPreferences.type.trim()) {
prefsObj.type = this.commonPreferences.type.trim();
}
// Add custom preferences (can override common ones)
this.preferences.forEach(pref => {

View File

@@ -10,22 +10,22 @@
<!-- Notifications -->
<div class="fixed top-20 right-4 z-50 space-y-2" style="max-width: 400px;">
<template x-for="notification in notifications" :key="notification.id">
<div x-show="true"
x-transition:enter="transform ease-out duration-300 transition"
x-transition:enter-start="translate-x-full opacity-0"
x-transition:enter-end="translate-x-0 opacity-100"
x-transition:leave="transform ease-in duration-200 transition"
x-transition:leave-start="translate-x-0 opacity-100"
x-transition:leave-end="translate-x-full opacity-0"
<div x-show="true"
x-transition:enter="transition ease-out duration-200"
x-transition:enter-start="opacity-0"
x-transition:enter-end="opacity-100"
x-transition:leave="transition ease-in duration-150"
x-transition:leave-start="opacity-100"
x-transition:leave-end="opacity-0"
:class="notification.type === 'error' ? 'bg-red-500' : 'bg-green-500'"
class="rounded-lg shadow-xl p-4 text-white flex items-start space-x-3">
class="rounded-lg p-4 text-white flex items-start space-x-3">
<div class="flex-shrink-0">
<i :class="notification.type === 'error' ? 'fas fa-exclamation-circle' : 'fas fa-check-circle'" class="text-xl"></i>
</div>
<div class="flex-1 min-w-0">
<p class="text-sm font-medium break-words" x-text="notification.message"></p>
</div>
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:text-gray-200">
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:opacity-80 transition-opacity">
<i class="fas fa-times"></i>
</button>
</div>
@@ -35,35 +35,34 @@
<div class="container mx-auto px-4 py-8 flex-grow">
<!-- Hero Header -->
<div class="relative bg-[#1E293B] border border-[#38BDF8]/20 rounded-3xl shadow-2xl shadow-[#38BDF8]/10 p-8 mb-12 overflow-hidden">
<!-- Background Pattern -->
<div class="absolute inset-0 opacity-10">
<div class="absolute inset-0 bg-gradient-to-r from-[#38BDF8]/20 to-[#8B5CF6]/20"></div>
<div class="absolute top-0 left-0 w-full h-full" style="background-image: radial-gradient(circle at 1px 1px, rgba(56,189,248,0.15) 1px, transparent 0); background-size: 20px 20px;"></div>
</div>
<div class="relative max-w-5xl mx-auto text-center">
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl p-8 mb-12">
<div class="max-w-5xl mx-auto text-center">
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
<span class="bg-clip-text text-transparent bg-gradient-to-r from-[#38BDF8] via-[#8B5CF6] to-[#38BDF8]">
Model Gallery
</span>
</h1>
<p class="text-lg md:text-xl text-gray-300 mb-6 font-light">
<p class="text-lg md:text-xl text-[#94A3B8] mb-6 font-light">
Discover and install AI models from our curated collection
</p>
<div class="flex flex-wrap justify-center items-center gap-6 text-sm md:text-base">
<div class="flex items-center bg-white/10 rounded-full px-4 py-2">
<div class="w-2 h-2 bg-indigo-400 rounded-full mr-2 animate-pulse"></div>
<div class="flex items-center bg-[#101827] rounded-lg px-4 py-2">
<div class="w-2 h-2 bg-indigo-400 rounded-full mr-2"></div>
<span class="font-semibold text-indigo-300" x-text="availableModels"></span>
<span class="text-gray-300 ml-1">models available</span>
<span class="text-[#94A3B8] ml-1">models available</span>
</div>
<div class="flex items-center bg-white/10 rounded-full px-4 py-2">
<div class="w-2 h-2 bg-purple-400 rounded-full mr-2 animate-pulse"></div>
<a href="/manage" class="flex items-center bg-[#101827] hover:bg-[#1E293B] rounded-lg px-4 py-2 transition-colors border border-[#38BDF8]/30 hover:border-[#38BDF8]/50">
<div class="w-2 h-2 bg-emerald-400 rounded-full mr-2"></div>
<span class="font-semibold text-emerald-300" x-text="installedModels"></span>
<span class="text-[#94A3B8] ml-1">installed</span>
</a>
<div class="flex items-center bg-[#101827] rounded-lg px-4 py-2">
<div class="w-2 h-2 bg-purple-400 rounded-full mr-2"></div>
<span class="font-semibold text-purple-300" x-text="repositories.length"></span>
<span class="text-gray-300 ml-1">repositories</span>
<span class="text-[#94A3B8] ml-1">repositories</span>
</div>
<a href="https://localai.io/models/" target="_blank"
class="flex items-center bg-blue-600/80 hover:bg-blue-600 text-white px-4 py-2 rounded-full transition-all duration-300 hover:scale-105">
class="inline-flex items-center bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg transition-colors">
<i class="fas fa-info-circle mr-2"></i>
<span>Documentation</span>
<i class="fas fa-external-link-alt ml-2 text-xs"></i>
@@ -75,28 +74,26 @@
{{template "views/partials/inprogress" .}}
<!-- Search and Filter Section -->
<div class="relative bg-gradient-to-br from-gray-800/80 to-gray-900/80 rounded-2xl p-8 mb-8 shadow-xl border border-gray-700/50 backdrop-blur-sm">
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-blue-500/5 to-purple-500/5"></div>
<div class="relative">
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl p-8 mb-8">
<div>
<!-- Search Input -->
<div class="mb-8">
<h3 class="text-xl font-semibold text-white mb-4 flex items-center">
<i class="fas fa-search mr-3 text-blue-400"></i>
<h3 class="text-xl font-semibold text-[#E5E7EB] mb-4 flex items-center">
<i class="fas fa-search mr-3 text-[#38BDF8]"></i>
Find Your Perfect Model
</h3>
<div class="relative">
<div class="absolute inset-y-0 start-0 flex items-center ps-4 pointer-events-none">
<i class="fas fa-search text-gray-400"></i>
<i class="fas fa-search text-[#94A3B8]"></i>
</div>
<input
x-model="searchTerm"
@input.debounce.500ms="fetchModels()"
class="w-full pl-12 pr-16 py-4 text-base font-normal text-gray-300 bg-gray-900/90 border border-gray-700/70 rounded-xl transition-all duration-300 focus:text-gray-200 focus:bg-gray-900 focus:border-blue-500 focus:ring-2 focus:ring-blue-500/50 focus:outline-none"
class="w-full pl-12 pr-16 py-4 text-base font-normal text-[#E5E7EB] bg-[#101827] border border-[#1E293B] rounded-lg transition-colors focus:text-[#E5E7EB] focus:bg-[#101827] focus:border-[#38BDF8] focus:ring-2 focus:ring-[#38BDF8]/50 focus:outline-none"
type="search"
placeholder="Search models by name, tag, or description...">
<span class="absolute right-4 top-4" x-show="loading">
<svg class="animate-spin h-6 w-6 text-blue-500" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<svg class="animate-spin h-6 w-6 text-[#38BDF8]" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
@@ -106,49 +103,49 @@
<!-- Filter by Type -->
<div class="mb-8">
<h3 class="text-lg font-semibold text-white mb-4 flex items-center">
<i class="fas fa-filter mr-3 text-purple-400"></i>
<h3 class="text-lg font-semibold text-[#E5E7EB] mb-4 flex items-center">
<i class="fas fa-filter mr-3 text-[#8B5CF6]"></i>
Filter by Model Type
</h3>
<div class="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-4 xl:grid-cols-8 gap-3">
<button @click="filterByTerm('tts')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-indigo-600/80 to-indigo-700/80 hover:from-indigo-600 hover:to-indigo-700 text-indigo-100 border border-indigo-500/30 hover:border-indigo-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-indigo-500/25">
<i class="fas fa-microphone mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-indigo-600/20 hover:bg-indigo-600/30 text-indigo-300 border border-indigo-500/30 transition-colors">
<i class="fas fa-microphone mr-2"></i>
<span>TTS</span>
</button>
<button @click="filterByTerm('stablediffusion')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-purple-600/80 to-purple-700/80 hover:from-purple-600 hover:to-purple-700 text-purple-100 border border-purple-500/30 hover:border-purple-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-purple-500/25">
<i class="fas fa-image mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-purple-600/20 hover:bg-purple-600/30 text-purple-300 border border-purple-500/30 transition-colors">
<i class="fas fa-image mr-2"></i>
<span>Image</span>
</button>
<button @click="filterByTerm('llm')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-blue-600/80 to-blue-700/80 hover:from-blue-600 hover:to-blue-700 text-blue-100 border border-blue-500/30 hover:border-blue-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-blue-500/25">
<i class="fas fa-comment-alt mr-2 group-hover:animate-bounce"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-blue-600/20 hover:bg-blue-600/30 text-blue-300 border border-blue-500/30 transition-colors">
<i class="fas fa-comment-alt mr-2"></i>
<span>LLM</span>
</button>
<button @click="filterByTerm('multimodal')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-green-600/80 to-green-700/80 hover:from-green-600 hover:to-green-700 text-green-100 border border-green-500/30 hover:border-green-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-green-500/25">
<i class="fas fa-object-group mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-green-600/20 hover:bg-green-600/30 text-green-300 border border-green-500/30 transition-colors">
<i class="fas fa-object-group mr-2"></i>
<span>Multimodal</span>
</button>
<button @click="filterByTerm('embedding')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-cyan-600/80 to-cyan-700/80 hover:from-cyan-600 hover:to-cyan-700 text-cyan-100 border border-cyan-500/30 hover:border-cyan-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-cyan-500/25">
<i class="fas fa-vector-square mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-cyan-600/20 hover:bg-cyan-600/30 text-cyan-300 border border-cyan-500/30 transition-colors">
<i class="fas fa-vector-square mr-2"></i>
<span>Embedding</span>
</button>
<button @click="filterByTerm('rerank')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-amber-600/80 to-amber-700/80 hover:from-amber-600 hover:to-amber-700 text-amber-100 border border-amber-500/30 hover:border-amber-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-amber-500/25">
<i class="fas fa-sort-amount-up mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-amber-600/20 hover:bg-amber-600/30 text-amber-300 border border-amber-500/30 transition-colors">
<i class="fas fa-sort-amount-up mr-2"></i>
<span>Rerank</span>
</button>
<button @click="filterByTerm('whisper')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-teal-600/80 to-teal-700/80 hover:from-teal-600 hover:to-teal-700 text-teal-100 border border-teal-500/30 hover:border-teal-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-teal-500/25">
<i class="fas fa-headphones mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-teal-600/20 hover:bg-teal-600/30 text-teal-300 border border-teal-500/30 transition-colors">
<i class="fas fa-headphones mr-2"></i>
<span>Whisper</span>
</button>
<button @click="filterByTerm('object-detection')"
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-red-600/80 to-red-700/80 hover:from-red-600 hover:to-red-700 text-red-100 border border-red-500/30 hover:border-red-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-red-500/25">
<i class="fas fa-eye mr-2 group-hover:animate-pulse"></i>
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-red-600/20 hover:bg-red-600/30 text-red-300 border border-red-500/30 transition-colors">
<i class="fas fa-eye mr-2"></i>
<span>Vision</span>
</button>
</div>
@@ -156,16 +153,16 @@
<!-- Filter by Tags -->
<div x-show="allTags.length > 0">
<h3 class="text-lg font-semibold text-white mb-4 flex items-center">
<i class="fas fa-tags mr-3 text-pink-400"></i>
<h3 class="text-lg font-semibold text-[#E5E7EB] mb-4 flex items-center">
<i class="fas fa-tags mr-3 text-[#8B5CF6]"></i>
Browse by Tags
</h3>
<div class="max-h-32 overflow-y-auto scrollbar-thin scrollbar-thumb-gray-600/50 scrollbar-track-gray-800/50 pr-2">
<div class="max-h-32 overflow-y-auto pr-2">
<div class="flex flex-wrap gap-2">
<template x-for="tag in allTags" :key="tag">
<button @click="filterByTerm(tag)"
class="group inline-flex items-center text-xs px-3 py-2 rounded-full bg-gray-700/60 hover:bg-gray-600/80 text-gray-300 hover:text-white border border-gray-600/50 hover:border-gray-500/70 transition-all duration-200 ease-in-out transform hover:scale-105">
<i class="fas fa-tag text-xs mr-2 group-hover:animate-pulse"></i>
class="inline-flex items-center text-xs px-3 py-2 rounded bg-[#101827] hover:bg-[#101827]/80 text-[#94A3B8] hover:text-[#E5E7EB] border border-[#1E293B] transition-colors">
<i class="fas fa-tag text-xs mr-2"></i>
<span x-text="tag"></span>
</button>
</template>
@@ -210,10 +207,14 @@
<tr class="hover:bg-[#38BDF8]/10 transition-colors duration-200">
<!-- Icon -->
<td class="px-6 py-4">
<img :src="model.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
class="w-12 h-12 object-cover rounded-lg border border-[#38BDF8]/30"
loading="lazy"
:alt="model.name">
<div class="w-12 h-12 rounded-lg border border-[#38BDF8]/30 flex items-center justify-center bg-[#101827]">
<img x-show="model.icon"
:src="model.icon"
class="w-full h-full object-cover rounded-lg"
loading="lazy"
:alt="model.name">
<i x-show="!model.icon" class="fas fa-brain text-xl text-[#38BDF8]"></i>
</div>
</td>
<!-- Model Name -->
@@ -355,9 +356,13 @@
<!-- Modal Body -->
<div class="p-4 md:p-5 space-y-4 overflow-y-auto flex-1 min-h-0">
<div class="flex justify-center items-center">
<img :src="selectedModel?.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
class="lazy rounded-t-lg max-h-48 max-w-96 object-cover mt-3"
loading="lazy">
<div class="w-48 h-48 rounded-lg border border-gray-300 dark:border-gray-600 flex items-center justify-center bg-gray-100 dark:bg-gray-800 mt-3">
<img x-show="selectedModel?.icon"
:src="selectedModel?.icon"
class="rounded-lg max-h-48 max-w-96 object-cover"
loading="lazy">
<i x-show="!selectedModel?.icon" class="fas fa-brain text-6xl text-gray-400 dark:text-gray-500"></i>
</div>
</div>
<div class="text-base leading-relaxed text-gray-500 dark:text-gray-400 break-words max-w-full markdown-content" x-html="renderMarkdown(selectedModel?.description)"></div>
<hr>
@@ -424,8 +429,8 @@
<button @click="goToPage(currentPage - 1)"
:disabled="currentPage <= 1"
:class="currentPage <= 1 ? 'opacity-50 cursor-not-allowed' : ''"
class="group flex items-center justify-center h-12 w-12 bg-gray-700/80 hover:bg-indigo-600 text-gray-300 hover:text-white rounded-xl shadow-lg transition-all duration-300 ease-in-out transform hover:scale-110">
<i class="fas fa-chevron-left group-hover:animate-pulse"></i>
class="flex items-center justify-center h-12 w-12 bg-[#1E293B] hover:bg-indigo-600 text-[#94A3B8] hover:text-white rounded-lg transition-colors">
<i class="fas fa-chevron-left"></i>
</button>
<div class="text-gray-300 text-sm font-medium px-4">
<span class="text-gray-400">Page</span>
@@ -559,6 +564,7 @@ function modelsGallery() {
currentPage: 1,
totalPages: 1,
availableModels: 0,
installedModels: 0,
selectedModel: null,
jobProgress: {},
notifications: [],
@@ -597,6 +603,7 @@ function modelsGallery() {
this.currentPage = data.currentPage || 1;
this.totalPages = data.totalPages || 1;
this.availableModels = data.availableModels || 0;
this.installedModels = data.installedModels || 0;
} catch (error) {
console.error('Error fetching models:', error);
} finally {

Some files were not shown because too many files have changed in this diff Show More