mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 16:51:44 -04:00
Compare commits
25 Commits
deps/llama
...
docs/impro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e8a54f4b6 | ||
|
|
18d11396cd | ||
|
|
93cd688f40 | ||
|
|
721c3f962b | ||
|
|
fb834805db | ||
|
|
839aa7b42b | ||
|
|
e963a45d66 | ||
|
|
c313b2c671 | ||
|
|
137f16336e | ||
|
|
d7f9f3ac93 | ||
|
|
cd7d384500 | ||
|
|
d1a0dd10e6 | ||
|
|
be8cf838c2 | ||
|
|
3276d1cdaf | ||
|
|
5e5f01badd | ||
|
|
6d0f646c37 | ||
|
|
99d31667f8 | ||
|
|
47b546afdc | ||
|
|
a09d49da43 | ||
|
|
1cdcaf0152 | ||
|
|
03e9f4b140 | ||
|
|
7129409bf6 | ||
|
|
d9e9ec6825 | ||
|
|
b82645d28d | ||
|
|
735ca757fa |
@@ -156,6 +156,8 @@ message PredictOptions {
|
||||
string CorrelationId = 47;
|
||||
string Tools = 48; // JSON array of available tools/functions for tool calling
|
||||
string ToolChoice = 49; // JSON string or object specifying tool choice behavior
|
||||
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
|
||||
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
|
||||
}
|
||||
|
||||
// The response message containing the result
|
||||
@@ -166,6 +168,7 @@ message Reply {
|
||||
double timing_prompt_processing = 4;
|
||||
double timing_token_generation = 5;
|
||||
bytes audio = 6;
|
||||
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
|
||||
}
|
||||
|
||||
message GrammarTrigger {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=7d019cff744b73084b15ca81ba9916f3efab1223
|
||||
LLAMA_VERSION?=80deff3648b93727422461c41c7279ef1dac7452
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Extract logprobs and top_logprobs from proto and add to JSON data
|
||||
// Following server.cpp pattern: logprobs maps to n_probs when provided
|
||||
if (predict->logprobs() > 0) {
|
||||
data["logprobs"] = predict->logprobs();
|
||||
// Map logprobs to n_probs (following server.cpp line 369 pattern)
|
||||
// n_probs will be set by params_from_json_cmpl if logprobs is provided
|
||||
data["n_probs"] = predict->logprobs();
|
||||
SRV_INF("Using logprobs: %d\n", predict->logprobs());
|
||||
}
|
||||
if (predict->toplogprobs() > 0) {
|
||||
data["top_logprobs"] = predict->toplogprobs();
|
||||
SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
|
||||
}
|
||||
|
||||
// Extract logit_bias from proto and add to JSON data
|
||||
if (!predict->logitbias().empty()) {
|
||||
try {
|
||||
// Parse logit_bias JSON string from proto
|
||||
json logit_bias_json = json::parse(predict->logitbias());
|
||||
// Add to data - llama.cpp server expects it as an object (map)
|
||||
data["logit_bias"] = logit_bias_json;
|
||||
SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
|
||||
} catch (const json::parse_error& e) {
|
||||
SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
data["ignore_eos"] = predict->ignoreeos();
|
||||
data["embeddings"] = predict->embeddings();
|
||||
|
||||
@@ -568,6 +596,28 @@ public:
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
// Helper function to extract logprobs from JSON response
|
||||
static json extract_logprobs_from_json(const json& res_json) {
|
||||
json logprobs_json = json::object();
|
||||
|
||||
// Check for OAI-compatible format: choices[0].logprobs
|
||||
if (res_json.contains("choices") && res_json["choices"].is_array() &&
|
||||
res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
|
||||
logprobs_json = res_json["choices"][0]["logprobs"];
|
||||
}
|
||||
// Check for non-OAI format: completion_probabilities
|
||||
else if (res_json.contains("completion_probabilities")) {
|
||||
// Convert completion_probabilities to OAI format
|
||||
logprobs_json["content"] = res_json["completion_probabilities"];
|
||||
}
|
||||
// Check for direct logprobs field
|
||||
else if (res_json.contains("logprobs")) {
|
||||
logprobs_json = res_json["logprobs"];
|
||||
}
|
||||
|
||||
return logprobs_json;
|
||||
}
|
||||
|
||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||
json data = parse_options(true, request, ctx_server);
|
||||
|
||||
@@ -579,7 +629,8 @@ public:
|
||||
|
||||
|
||||
auto completion_id = gen_chatcmplid();
|
||||
std::unordered_set<int> task_ids;
|
||||
// need to store the reader as a pointer, so that it won't be destroyed when the handle returns
|
||||
const auto rd = std::make_shared<server_response_reader>(ctx_server);
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
@@ -620,6 +671,11 @@ public:
|
||||
content_val = msg.content();
|
||||
}
|
||||
|
||||
// If content is an object (e.g., from tool call failures), convert to string
|
||||
if (content_val.is_object()) {
|
||||
content_val = content_val.dump();
|
||||
}
|
||||
|
||||
// If content is a string and this is the last user message with images/audio, combine them
|
||||
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
|
||||
json content_array = json::array();
|
||||
@@ -871,18 +927,91 @@ public:
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
rd->post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
|
||||
}
|
||||
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
|
||||
// Get first result for error checking (following server.cpp pattern)
|
||||
server_task_result_ptr first_result = rd->next([&context]() { return context->IsCancelled(); });
|
||||
if (first_result == nullptr) {
|
||||
// connection is closed
|
||||
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
|
||||
} else if (first_result->is_error()) {
|
||||
json error_json = first_result->to_json();
|
||||
backend::Reply reply;
|
||||
reply.set_message(error_json.value("message", ""));
|
||||
writer->Write(reply);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Process first result
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
std::string completion_text = res.value("content", "");
|
||||
|
||||
backend::Reply reply;
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = res.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (res.contains("timings")) {
|
||||
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(res);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
std::string completion_text = first_res_json.value("content", "");
|
||||
|
||||
backend::Reply reply;
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (first_res_json.contains("timings")) {
|
||||
double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(first_res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
writer->Write(reply);
|
||||
}
|
||||
|
||||
// Process subsequent results
|
||||
while (rd->has_next()) {
|
||||
// Check if context is cancelled before processing result
|
||||
if (context->IsCancelled()) {
|
||||
ctx_server.cancel_tasks(task_ids);
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
|
||||
auto result = rd->next([&context]() { return context->IsCancelled(); });
|
||||
if (result == nullptr) {
|
||||
// connection is closed
|
||||
break;
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
@@ -904,9 +1033,13 @@ public:
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Log Request Correlation Id
|
||||
|
||||
// Send the reply
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(res);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
@@ -926,24 +1059,16 @@ public:
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
// Send the reply
|
||||
writer->Write(reply);
|
||||
|
||||
writer->Write(reply);
|
||||
}
|
||||
return true;
|
||||
}, [&](const json & error_data) {
|
||||
backend::Reply reply;
|
||||
reply.set_message(error_data.value("content", ""));
|
||||
writer->Write(reply);
|
||||
return true;
|
||||
}, [&context]() {
|
||||
// Check if the gRPC context is cancelled
|
||||
return context->IsCancelled();
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
}
|
||||
|
||||
// Check if context was cancelled during processing
|
||||
if (context->IsCancelled()) {
|
||||
@@ -963,7 +1088,7 @@ public:
|
||||
}
|
||||
std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl;
|
||||
auto completion_id = gen_chatcmplid();
|
||||
std::unordered_set<int> task_ids;
|
||||
const auto rd = std::make_shared<server_response_reader>(ctx_server);
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
@@ -1004,6 +1129,11 @@ public:
|
||||
content_val = msg.content();
|
||||
}
|
||||
|
||||
// If content is an object (e.g., from tool call failures), convert to string
|
||||
if (content_val.is_object()) {
|
||||
content_val = content_val.dump();
|
||||
}
|
||||
|
||||
// If content is a string and this is the last user message with images/audio, combine them
|
||||
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
|
||||
json content_array = json::array();
|
||||
@@ -1261,9 +1391,7 @@ public:
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
rd->post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
|
||||
}
|
||||
@@ -1271,51 +1399,71 @@ public:
|
||||
|
||||
std::cout << "[DEBUG] Waiting for results..." << std::endl;
|
||||
|
||||
// Check cancellation before waiting for results
|
||||
if (context->IsCancelled()) {
|
||||
ctx_server.cancel_tasks(task_ids);
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
// Wait for all results
|
||||
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
|
||||
|
||||
if (all_results.is_terminated) {
|
||||
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
|
||||
}
|
||||
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
|
||||
if (results.size() == 1) {
|
||||
} else if (all_results.error) {
|
||||
std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl;
|
||||
reply->set_message(all_results.error->to_json().value("message", ""));
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred"));
|
||||
} else {
|
||||
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
|
||||
if (all_results.results.size() == 1) {
|
||||
// single result
|
||||
reply->set_message(results[0]->to_json().value("content", ""));
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
int32_t tokens_predicted = results[0]->to_json().value("tokens_predicted", 0);
|
||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = results[0]->to_json().value("tokens_evaluated", 0);
|
||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (results[0]->to_json().contains("timings")) {
|
||||
double timing_prompt_processing = results[0]->to_json().at("timings").value("prompt_ms", 0.0);
|
||||
if (result_json.contains("timings")) {
|
||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = results[0]->to_json().at("timings").value("predicted_ms", 0.0);
|
||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (auto & res : results) {
|
||||
arr.push_back(res->to_json().value("content", ""));
|
||||
json logprobs_arr = json::array();
|
||||
bool has_logprobs = false;
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
json res_json = res->to_json();
|
||||
arr.push_back(res_json.value("content", ""));
|
||||
|
||||
// Extract logprobs for each result
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
has_logprobs = true;
|
||||
logprobs_arr.push_back(logprobs_json);
|
||||
} else {
|
||||
logprobs_arr.push_back(json::object());
|
||||
}
|
||||
}
|
||||
reply->set_message(arr);
|
||||
|
||||
// Set logprobs if any result has them
|
||||
if (has_logprobs) {
|
||||
std::string logprobs_str = logprobs_arr.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}, [&](const json & error_data) {
|
||||
std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
|
||||
reply->set_message(error_data.value("content", ""));
|
||||
}, [&context]() {
|
||||
// Check if the gRPC context is cancelled
|
||||
// This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
|
||||
return context->IsCancelled();
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
}
|
||||
|
||||
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
|
||||
|
||||
// Check if context was cancelled during processing
|
||||
@@ -1352,9 +1500,7 @@ public:
|
||||
|
||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
std::unordered_set<int> task_ids;
|
||||
const auto rd = std::make_shared<server_response_reader>(ctx_server);
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
@@ -1369,40 +1515,23 @@ public:
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
rd->post_tasks(std::move(tasks));
|
||||
}
|
||||
|
||||
// Check cancellation before waiting for results
|
||||
if (context->IsCancelled()) {
|
||||
ctx_server.cancel_tasks(task_ids);
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
// Wait for all results
|
||||
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
|
||||
|
||||
if (all_results.is_terminated) {
|
||||
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
|
||||
} else if (all_results.error) {
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
|
||||
}
|
||||
|
||||
// get the result
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
error = true;
|
||||
}, [&context]() {
|
||||
// Check if the gRPC context is cancelled
|
||||
return context->IsCancelled();
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
|
||||
// Check if context was cancelled during processing
|
||||
if (context->IsCancelled()) {
|
||||
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
|
||||
// Collect responses
|
||||
json responses = json::array();
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
|
||||
std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
|
||||
@@ -1453,9 +1582,7 @@ public:
|
||||
}
|
||||
|
||||
// Create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
std::unordered_set<int> task_ids;
|
||||
const auto rd = std::make_shared<server_response_reader>(ctx_server);
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
std::vector<std::string> documents;
|
||||
@@ -1473,40 +1600,23 @@ public:
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
rd->post_tasks(std::move(tasks));
|
||||
}
|
||||
|
||||
// Check cancellation before waiting for results
|
||||
if (context->IsCancelled()) {
|
||||
ctx_server.cancel_tasks(task_ids);
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
// Wait for all results
|
||||
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
|
||||
|
||||
if (all_results.is_terminated) {
|
||||
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
|
||||
} else if (all_results.error) {
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
|
||||
}
|
||||
|
||||
// Get the results
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
error = true;
|
||||
}, [&context]() {
|
||||
// Check if the gRPC context is cancelled
|
||||
return context->IsCancelled();
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
|
||||
// Check if context was cancelled during processing
|
||||
if (context->IsCancelled()) {
|
||||
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
|
||||
// Collect responses
|
||||
json responses = json::array();
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
// Sort responses by score in descending order
|
||||
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=a1867e0dad0b21b35afa43fc815dae60c9a139d6
|
||||
WHISPER_CPP_VERSION?=d9b7613b34a343848af572cc14467fc5e82fc788
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -62,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
log.Error().Err(err).Msg("error installing models")
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
log.Error().Err(err).Msg("error installing external backend")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -24,6 +25,7 @@ type LLMResponse struct {
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
AudioOutput string
|
||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
@@ -33,7 +35,7 @@ type TokenUsage struct {
|
||||
TimingTokenGeneration float64
|
||||
}
|
||||
|
||||
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) {
|
||||
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
@@ -45,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if !slices.Contains(modelNames, c.Name) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
|
||||
//return nil, err
|
||||
@@ -78,6 +80,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
opts.Audios = audios
|
||||
opts.Tools = tools
|
||||
opts.ToolChoice = toolChoice
|
||||
if logprobs != nil {
|
||||
opts.Logprobs = int32(*logprobs)
|
||||
}
|
||||
if topLogprobs != nil {
|
||||
opts.TopLogprobs = int32(*topLogprobs)
|
||||
}
|
||||
if len(logitBias) > 0 {
|
||||
// Serialize logit_bias map to JSON string for proto
|
||||
logitBiasJSON, err := json.Marshal(logitBias)
|
||||
if err == nil {
|
||||
opts.LogitBias = string(logitBiasJSON)
|
||||
}
|
||||
}
|
||||
|
||||
tokenUsage := TokenUsage{}
|
||||
|
||||
@@ -109,6 +124,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
|
||||
ss := ""
|
||||
var logprobs *schema.Logprobs
|
||||
|
||||
var partialRune []byte
|
||||
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
|
||||
@@ -120,6 +136,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
|
||||
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
|
||||
|
||||
// Parse logprobs from reply if present (collect from last chunk that has them)
|
||||
if len(reply.Logprobs) > 0 {
|
||||
var parsedLogprobs schema.Logprobs
|
||||
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
|
||||
logprobs = &parsedLogprobs
|
||||
}
|
||||
}
|
||||
|
||||
// Process complete runes and accumulate them
|
||||
var completeRunes []byte
|
||||
for len(partialRune) > 0 {
|
||||
@@ -145,6 +169,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
return LLMResponse{
|
||||
Response: ss,
|
||||
Usage: tokenUsage,
|
||||
Logprobs: logprobs,
|
||||
}, err
|
||||
} else {
|
||||
// TODO: Is the chicken bit the only way to get here? is that acceptable?
|
||||
@@ -167,9 +192,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
response = c.TemplateConfig.ReplyPrefix + response
|
||||
}
|
||||
|
||||
// Parse logprobs from reply if present
|
||||
var logprobs *schema.Logprobs
|
||||
if len(reply.Logprobs) > 0 {
|
||||
var parsedLogprobs schema.Logprobs
|
||||
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
|
||||
logprobs = &parsedLogprobs
|
||||
}
|
||||
}
|
||||
|
||||
return LLMResponse{
|
||||
Response: response,
|
||||
Usage: tokenUsage,
|
||||
Logprobs: logprobs,
|
||||
}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
}
|
||||
}
|
||||
|
||||
return &pb.PredictOptions{
|
||||
pbOpts := &pb.PredictOptions{
|
||||
Temperature: float32(*c.Temperature),
|
||||
TopP: float32(*c.TopP),
|
||||
NDraft: c.NDraft,
|
||||
@@ -249,4 +249,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
TailFreeSamplingZ: float32(*c.TFZ),
|
||||
TypicalP: float32(*c.TypicalP),
|
||||
}
|
||||
// Logprobs and TopLogprobs are set by the caller if provided
|
||||
return pbOpts
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState, true)
|
||||
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
|
||||
appHTTP := http.Explorer(db)
|
||||
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
if err := appHTTP.Shutdown(); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := appHTTP.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("error during shutdown")
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(e.Address)
|
||||
return appHTTP.Start(e.Address)
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState, true)
|
||||
err = startup.InstallModels(galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(r.Address)
|
||||
return appHTTP.Start(r.Address)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
|
||||
log.Error().Err(err).Msg("failed loading galleries")
|
||||
return "", err
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
||||
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
|
||||
return "", err
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/cogito"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -668,3 +669,40 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// BuildCogitoOptions generates cogito options from the model configuration
|
||||
// It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results
|
||||
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
|
||||
cogitoOpts := []cogito.Option{
|
||||
cogito.WithIterations(3), // default to 3 iterations
|
||||
cogito.WithMaxAttempts(3), // default to 3 attempts
|
||||
cogito.WithForceReasoning(),
|
||||
}
|
||||
|
||||
// Apply agent configuration options
|
||||
if c.Agent.EnableReasoning {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
|
||||
}
|
||||
|
||||
if c.Agent.EnablePlanning {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
|
||||
}
|
||||
|
||||
if c.Agent.EnableMCPPrompts {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
|
||||
}
|
||||
|
||||
if c.Agent.EnablePlanReEvaluator {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
|
||||
}
|
||||
|
||||
if c.Agent.MaxIterations != 0 {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
|
||||
}
|
||||
|
||||
if c.Agent.MaxAttempts != 0 {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
|
||||
}
|
||||
|
||||
return cogitoOpts
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -69,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
|
||||
}
|
||||
|
||||
// InstallBackendFromGallery installs a backend from the gallery.
|
||||
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
|
||||
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
|
||||
if !force {
|
||||
// check if we already have the backend installed
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
@@ -109,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
|
||||
|
||||
// Then, let's install the best backend
|
||||
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
|
||||
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -134,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
return nil
|
||||
}
|
||||
|
||||
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
|
||||
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
|
||||
}
|
||||
|
||||
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
||||
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
|
||||
if err != nil {
|
||||
@@ -164,11 +165,17 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
|
||||
}
|
||||
} else {
|
||||
uri := downloader.URI(config.URI)
|
||||
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
success := false
|
||||
// Try to download from mirrors
|
||||
for _, mirror := range config.Mirrors {
|
||||
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
// Check for cancellation before trying next mirror
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
|
||||
)
|
||||
must(err)
|
||||
sysDefault.GPUVendor = "" // force default selection
|
||||
backs, err := ListSystemBackends(sysDefault)
|
||||
backs, err := ListSystemBackends(sysDefault)
|
||||
must(err)
|
||||
aliasBack, ok := backs.Get("llama-cpp")
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
|
||||
must(err)
|
||||
sysNvidia.GPUVendor = "nvidia"
|
||||
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
|
||||
backs, err = ListSystemBackends(sysNvidia)
|
||||
backs, err = ListSystemBackends(sysNvidia)
|
||||
must(err)
|
||||
aliasBack, ok = backs.Get("llama-cpp")
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {
|
||||
|
||||
Describe("InstallBackendFromGallery", func() {
|
||||
It("should return error when backend is not found", func() {
|
||||
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
|
||||
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
|
||||
})
|
||||
|
||||
It("should install backend from gallery", func() {
|
||||
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
|
||||
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
|
||||
})
|
||||
@@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -561,7 +562,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
system.WithBackendPath(newPath),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
|
||||
Expect(newPath).To(BeADirectory())
|
||||
})
|
||||
@@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
|
||||
@@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
|
||||
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
})
|
||||
@@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
|
||||
var config T
|
||||
uri := downloader.URI(url)
|
||||
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
|
||||
return config, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func ReadConfigFile[T any](filePath string) (*T, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
var DefaultImporters = []Importer{
|
||||
var defaultImporters = []Importer{
|
||||
&LlamaCPPImporter{},
|
||||
&MLXImporter{},
|
||||
&VLLMImporter{},
|
||||
&TransformersImporter{},
|
||||
}
|
||||
|
||||
type Details struct {
|
||||
@@ -52,7 +54,7 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
for _, importer := range DefaultImporters {
|
||||
for _, importer := range defaultImporters {
|
||||
if importer.Match(details) {
|
||||
modelConfig, err = importer.Import(details)
|
||||
if err != nil {
|
||||
|
||||
@@ -80,10 +80,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
mmprojQuantsList = strings.Split(mmprojQuants, ",")
|
||||
}
|
||||
|
||||
embeddings, _ := preferencesMap["embeddings"].(string)
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Options: []string{"use_jinja:true"},
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
UseTokenizerTemplate: true,
|
||||
@@ -95,6 +98,11 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
},
|
||||
}
|
||||
|
||||
if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
|
||||
trueV := true
|
||||
modelConfig.Embeddings = &trueV
|
||||
}
|
||||
|
||||
cfg := gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
|
||||
@@ -5,20 +5,21 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("LlamaCPPImporter", func() {
|
||||
var importer *importers.LlamaCPPImporter
|
||||
var importer *LlamaCPPImporter
|
||||
|
||||
BeforeEach(func() {
|
||||
importer = &importers.LlamaCPPImporter{}
|
||||
importer = &LlamaCPPImporter{}
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("should match when URI ends with .gguf", func() {
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/model.gguf",
|
||||
}
|
||||
|
||||
@@ -28,7 +29,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
|
||||
It("should match when backend preference is llama-cpp", func() {
|
||||
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
@@ -38,7 +39,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
})
|
||||
|
||||
It("should not match when URI does not end with .gguf and no backend preference", func() {
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/model.bin",
|
||||
}
|
||||
|
||||
@@ -48,7 +49,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
|
||||
It("should not match when backend preference is different", func() {
|
||||
preferences := json.RawMessage(`{"backend": "mlx"}`)
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
@@ -59,7 +60,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
|
||||
It("should return false when JSON preferences are invalid", func() {
|
||||
preferences := json.RawMessage(`invalid json`)
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
@@ -72,7 +73,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
|
||||
Context("Import", func() {
|
||||
It("should import model config with default name and description", func() {
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
}
|
||||
|
||||
@@ -89,7 +90,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
|
||||
It("should import model config with custom name and description from preferences", func() {
|
||||
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
@@ -106,7 +107,7 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
|
||||
It("should handle invalid JSON preferences", func() {
|
||||
preferences := json.RawMessage(`invalid json`)
|
||||
details := importers.Details{
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
110
core/gallery/importers/transformers.go
Normal file
110
core/gallery/importers/transformers.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &TransformersImporter{}
|
||||
|
||||
type TransformersImporter struct{}
|
||||
|
||||
func (i *TransformersImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
err = json.Unmarshal(preferences, &preferencesMap)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
b, ok := preferencesMap["backend"].(string)
|
||||
if ok && b == "transformers" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
for _, file := range details.HuggingFace.Files {
|
||||
if strings.Contains(file.Path, "tokenizer.json") ||
|
||||
strings.Contains(file.Path, "tokenizer_config.json") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
err = json.Unmarshal(preferences, &preferencesMap)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
backend := "transformers"
|
||||
b, ok := preferencesMap["backend"].(string)
|
||||
if ok {
|
||||
backend = b
|
||||
}
|
||||
|
||||
modelType, ok := preferencesMap["type"].(string)
|
||||
if !ok {
|
||||
modelType = "AutoModelForCausalLM"
|
||||
}
|
||||
|
||||
quantization, ok := preferencesMap["quantization"].(string)
|
||||
if !ok {
|
||||
quantization = ""
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: details.URI,
|
||||
},
|
||||
},
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
UseTokenizerTemplate: true,
|
||||
},
|
||||
}
|
||||
modelConfig.ModelType = modelType
|
||||
modelConfig.Quantization = quantization
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
219
core/gallery/importers/transformers_test.go
Normal file
219
core/gallery/importers/transformers_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TransformersImporter", func() {
|
||||
var importer *TransformersImporter
|
||||
|
||||
BeforeEach(func() {
|
||||
importer = &TransformersImporter{}
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("should match when backend preference is transformers", func() {
|
||||
preferences := json.RawMessage(`{"backend": "transformers"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should match when HuggingFace details contain tokenizer.json", func() {
|
||||
hfDetails := &hfapi.ModelDetails{
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "tokenizer.json"},
|
||||
},
|
||||
}
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/model",
|
||||
HuggingFace: hfDetails,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should match when HuggingFace details contain tokenizer_config.json", func() {
|
||||
hfDetails := &hfapi.ModelDetails{
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "tokenizer_config.json"},
|
||||
},
|
||||
}
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/model",
|
||||
HuggingFace: hfDetails,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should not match when URI has no tokenizer files and no backend preference", func() {
|
||||
details := Details{
|
||||
URI: "https://example.com/model.bin",
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should not match when backend preference is different", func() {
|
||||
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when JSON preferences are invalid", func() {
|
||||
preferences := json.RawMessage(`invalid json`)
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import", func() {
|
||||
It("should import model config with default name and description", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-model"))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
|
||||
})
|
||||
|
||||
It("should import model config with custom name and description from preferences", func() {
|
||||
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("custom-model"))
|
||||
Expect(modelConfig.Description).To(Equal("Custom description"))
|
||||
})
|
||||
|
||||
It("should use custom model type from preferences", func() {
|
||||
preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
|
||||
})
|
||||
|
||||
It("should use default model type when not specified", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
|
||||
})
|
||||
|
||||
It("should use custom backend from preferences", func() {
|
||||
preferences := json.RawMessage(`{"backend": "transformers"}`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
|
||||
})
|
||||
|
||||
It("should use quantization from preferences", func() {
|
||||
preferences := json.RawMessage(`{"quantization": "int8"}`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
|
||||
})
|
||||
|
||||
It("should handle invalid JSON preferences", func() {
|
||||
preferences := json.RawMessage(`invalid json`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
_, err := importer.Import(details)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should extract filename correctly from URI with path", func() {
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/test/path/to/model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("model"))
|
||||
})
|
||||
|
||||
It("should include use_tokenizer_template in config", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
|
||||
})
|
||||
|
||||
It("should include known_usecases in config", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
|
||||
})
|
||||
})
|
||||
})
|
||||
98
core/gallery/importers/vllm.go
Normal file
98
core/gallery/importers/vllm.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &VLLMImporter{}
|
||||
|
||||
type VLLMImporter struct{}
|
||||
|
||||
func (i *VLLMImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
err = json.Unmarshal(preferences, &preferencesMap)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
b, ok := preferencesMap["backend"].(string)
|
||||
if ok && b == "vllm" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
for _, file := range details.HuggingFace.Files {
|
||||
if strings.Contains(file.Path, "tokenizer.json") ||
|
||||
strings.Contains(file.Path, "tokenizer_config.json") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
err = json.Unmarshal(preferences, &preferencesMap)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
backend := "vllm"
|
||||
b, ok := preferencesMap["backend"].(string)
|
||||
if ok {
|
||||
backend = b
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: details.URI,
|
||||
},
|
||||
},
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
UseTokenizerTemplate: true,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
181
core/gallery/importers/vllm_test.go
Normal file
181
core/gallery/importers/vllm_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("VLLMImporter", func() {
|
||||
var importer *VLLMImporter
|
||||
|
||||
BeforeEach(func() {
|
||||
importer = &VLLMImporter{}
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("should match when backend preference is vllm", func() {
|
||||
preferences := json.RawMessage(`{"backend": "vllm"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should match when HuggingFace details contain tokenizer.json", func() {
|
||||
hfDetails := &hfapi.ModelDetails{
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "tokenizer.json"},
|
||||
},
|
||||
}
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/model",
|
||||
HuggingFace: hfDetails,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should match when HuggingFace details contain tokenizer_config.json", func() {
|
||||
hfDetails := &hfapi.ModelDetails{
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "tokenizer_config.json"},
|
||||
},
|
||||
}
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/model",
|
||||
HuggingFace: hfDetails,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should not match when URI has no tokenizer files and no backend preference", func() {
|
||||
details := Details{
|
||||
URI: "https://example.com/model.bin",
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should not match when backend preference is different", func() {
|
||||
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when JSON preferences are invalid", func() {
|
||||
preferences := json.RawMessage(`invalid json`)
|
||||
details := Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
result := importer.Match(details)
|
||||
Expect(result).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import", func() {
|
||||
It("should import model config with default name and description", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-model"))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
|
||||
})
|
||||
|
||||
It("should import model config with custom name and description from preferences", func() {
|
||||
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("custom-model"))
|
||||
Expect(modelConfig.Description).To(Equal("Custom description"))
|
||||
})
|
||||
|
||||
It("should use custom backend from preferences", func() {
|
||||
preferences := json.RawMessage(`{"backend": "vllm"}`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
|
||||
})
|
||||
|
||||
It("should handle invalid JSON preferences", func() {
|
||||
preferences := json.RawMessage(`invalid json`)
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
_, err := importer.Import(details)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should extract filename correctly from URI with path", func() {
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/test/path/to/model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("model"))
|
||||
})
|
||||
|
||||
It("should include use_tokenizer_template in config", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
|
||||
})
|
||||
|
||||
It("should include known_usecases in config", func() {
|
||||
details := Details{
|
||||
URI: "https://huggingface.co/test/my-model",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -72,6 +73,7 @@ type PromptTemplate struct {
|
||||
|
||||
// Installs a model from the gallery
|
||||
func InstallModelFromGallery(
|
||||
ctx context.Context,
|
||||
modelGalleries, backendGalleries []config.Gallery,
|
||||
systemState *system.SystemState,
|
||||
modelLoader *model.ModelLoader,
|
||||
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
|
||||
|
||||
if len(model.URL) > 0 {
|
||||
var err error
|
||||
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
|
||||
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
|
||||
return err
|
||||
}
|
||||
|
||||
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
|
||||
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
|
||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
|
||||
|
||||
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
basePath := systemState.Model.ModelsPath
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0750)
|
||||
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
|
||||
|
||||
// Download files and verify their SHA
|
||||
for i, file := range config.Files {
|
||||
// Check for cancellation before each file
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
||||
|
||||
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
|
||||
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
|
||||
}
|
||||
}
|
||||
uri := downloader.URI(file.URI)
|
||||
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
|
||||
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
|
||||
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
|
||||
Expect(models[0].Installed).To(BeFalse())
|
||||
|
||||
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
||||
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
223
core/http/app.go
223
core/http/app.go
@@ -4,30 +4,23 @@ import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
|
||||
"github.com/gofiber/contrib/fiberzerolog"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/csrf"
|
||||
"github.com/gofiber/fiber/v2/middleware/favicon"
|
||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
|
||||
// swagger handler
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -49,85 +42,85 @@ var embedDirStatic embed.FS
|
||||
// @in header
|
||||
// @name Authorization
|
||||
|
||||
func API(application *application.Application) (*fiber.App, error) {
|
||||
func API(application *application.Application) (*echo.Echo, error) {
|
||||
e := echo.New()
|
||||
|
||||
fiberCfg := fiber.Config{
|
||||
Views: renderEngine(),
|
||||
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
|
||||
DisableStartupMessage: true,
|
||||
// Override default error handler
|
||||
// Set body limit
|
||||
if application.ApplicationConfig().UploadLimitMB > 0 {
|
||||
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
|
||||
}
|
||||
|
||||
// Set error handler
|
||||
if !application.ApplicationConfig().OpaqueErrors {
|
||||
// Normally, return errors as JSON responses
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
code := fiber.StatusInternalServerError
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
code := http.StatusInternalServerError
|
||||
var he *echo.HTTPError
|
||||
if errors.As(err, &he) {
|
||||
code = he.Code
|
||||
}
|
||||
|
||||
// Retrieve the custom status code if it's a *fiber.Error
|
||||
var e *fiber.Error
|
||||
if errors.As(err, &e) {
|
||||
code = e.Code
|
||||
// Handle 404 errors with HTML rendering when appropriate
|
||||
if code == http.StatusNotFound {
|
||||
notFoundHandler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Send custom error page
|
||||
return ctx.Status(code).JSON(
|
||||
schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
},
|
||||
)
|
||||
c.JSON(code, schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If OpaqueErrors are required, replace everything with a blank 500.
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
|
||||
return ctx.Status(500).SendString("")
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
code := http.StatusInternalServerError
|
||||
var he *echo.HTTPError
|
||||
if errors.As(err, &he) {
|
||||
code = he.Code
|
||||
}
|
||||
c.NoContent(code)
|
||||
}
|
||||
}
|
||||
|
||||
router := fiber.New(fiberCfg)
|
||||
// Set renderer
|
||||
e.Renderer = renderEngine()
|
||||
|
||||
router.Use(middleware.StripPathPrefix())
|
||||
// Hide banner
|
||||
e.HideBanner = true
|
||||
|
||||
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
|
||||
e.Pre(httpMiddleware.StripPathPrefix())
|
||||
|
||||
if application.ApplicationConfig().MachineTag != "" {
|
||||
router.Use(func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
|
||||
return c.Next()
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
// Returns true if the client requested upgrade to the WebSocket protocol
|
||||
return c.Next()
|
||||
// Custom logger middleware using zerolog
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
start := log.Logger.Info()
|
||||
err := next(c)
|
||||
start.
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Int("status", res.Status).
|
||||
Msg("HTTP request")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||
scheme := "http"
|
||||
if listenData.TLS {
|
||||
scheme = "https"
|
||||
}
|
||||
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
|
||||
return nil
|
||||
})
|
||||
|
||||
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
|
||||
logger := log.Logger
|
||||
router.Use(fiberzerolog.New(fiberzerolog.Config{
|
||||
Logger: &logger,
|
||||
}))
|
||||
|
||||
// Default middleware config
|
||||
|
||||
// Recover middleware
|
||||
if !application.ApplicationConfig().Debug {
|
||||
router.Use(recover.New())
|
||||
e.Use(middleware.Recover())
|
||||
}
|
||||
|
||||
// Metrics middleware
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := services.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
@@ -135,34 +128,40 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
router.Hooks().OnShutdown(func() error {
|
||||
return metricsService.Shutdown()
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(router)
|
||||
|
||||
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil || kaConfig == nil {
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(e)
|
||||
|
||||
// Get key auth middleware
|
||||
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||
}
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
// Favicon handler
|
||||
e.GET("/favicon.svg", func(c echo.Context) error {
|
||||
data, err := embedDirStatic.ReadFile("static/favicon.svg")
|
||||
if err != nil {
|
||||
return c.NoContent(http.StatusNotFound)
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "image/svg+xml")
|
||||
return c.Blob(http.StatusOK, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
router.Use(favicon.New(favicon.Config{
|
||||
URL: "/favicon.svg",
|
||||
FileSystem: httpFS,
|
||||
File: "static/favicon.svg",
|
||||
}))
|
||||
|
||||
router.Use("/static", filesystem.New(filesystem.Config{
|
||||
Root: httpFS,
|
||||
PathPrefix: "static",
|
||||
Browse: true,
|
||||
}))
|
||||
// Static files - use fs.Sub to create a filesystem rooted at "static"
|
||||
staticFS, err := fs.Sub(embedDirStatic, "static")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
|
||||
}
|
||||
e.StaticFS("/static", staticFS)
|
||||
|
||||
// Generated content directories
|
||||
if application.ApplicationConfig().GeneratedContentDir != "" {
|
||||
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
|
||||
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
|
||||
@@ -173,51 +172,53 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
os.MkdirAll(imagePath, 0750)
|
||||
os.MkdirAll(videoPath, 0750)
|
||||
|
||||
router.Static("/generated-audio", audioPath)
|
||||
router.Static("/generated-images", imagePath)
|
||||
router.Static("/generated-videos", videoPath)
|
||||
e.Static("/generated-audio", audioPath)
|
||||
e.Static("/generated-images", imagePath)
|
||||
e.Static("/generated-videos", videoPath)
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
|
||||
router.Use(v2keyauth.New(*kaConfig))
|
||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
|
||||
e.Use(keyAuthMiddleware)
|
||||
|
||||
// CORS middleware
|
||||
if application.ApplicationConfig().CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if application.ApplicationConfig().CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
|
||||
corsConfig := middleware.CORSConfig{}
|
||||
if application.ApplicationConfig().CORSAllowOrigins != "" {
|
||||
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
|
||||
}
|
||||
|
||||
router.Use(c)
|
||||
e.Use(middleware.CORSWithConfig(corsConfig))
|
||||
}
|
||||
|
||||
// CSRF middleware
|
||||
if application.ApplicationConfig().CSRF {
|
||||
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
|
||||
router.Use(csrf.New())
|
||||
e.Use(middleware.CSRF())
|
||||
}
|
||||
|
||||
requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||
var opcache *services.OpCache
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
opcache = services.NewOpCache(application.GalleryService())
|
||||
}
|
||||
|
||||
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
|
||||
|
||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
|
||||
}
|
||||
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Define a custom 404 handler
|
||||
// Note: keep this at the bottom!
|
||||
router.Use(notFoundHandler)
|
||||
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
|
||||
|
||||
return router, nil
|
||||
// Log startup message
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
log.Info().Msg("LocalAI API server shutting down")
|
||||
})
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
@@ -10,13 +10,14 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
openaigo "github.com/otiai10/openaigo"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
@@ -85,7 +87,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
|
||||
response := []gallery.GalleryModel{}
|
||||
uri := downloader.URI(url)
|
||||
// TODO: No tests currently seem to exercise file:// urls. Fix?
|
||||
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
|
||||
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
|
||||
// Unmarshal YAML data into a struct
|
||||
return json.Unmarshal(i, &response)
|
||||
})
|
||||
@@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
|
||||
|
||||
var _ = Describe("API test", func() {
|
||||
|
||||
var app *fiber.App
|
||||
var app *echo.Echo
|
||||
var client *openai.Client
|
||||
var client2 *openaigo.Client
|
||||
var c context.Context
|
||||
@@ -339,7 +341,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig(apiKey)
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -358,7 +364,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func(sc SpecContext) {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
@@ -547,7 +555,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -566,7 +578,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
@@ -755,7 +769,11 @@ var _ = Describe("API test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -773,7 +791,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
@@ -796,6 +816,83 @@ var _ = Describe("API test", func() {
|
||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns logprobs in chat completions when requested", func() {
|
||||
topLogprobsVal := 3
|
||||
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
|
||||
Model: "testmodel.ggml",
|
||||
LogProbs: true,
|
||||
TopLogProbs: topLogprobsVal,
|
||||
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(len(response.Choices)).To(Equal(1))
|
||||
Expect(response.Choices[0].Message).ToNot(BeNil())
|
||||
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
|
||||
// Verify logprobs are present and have correct structure
|
||||
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
|
||||
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
|
||||
|
||||
Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
|
||||
|
||||
foundatLeastToken := ""
|
||||
foundAtLeastBytes := []byte{}
|
||||
foundAtLeastTopLogprobBytes := []byte{}
|
||||
foundatLeastTopLogprob := ""
|
||||
// Verify logprobs content structure matches OpenAI format
|
||||
for _, logprobContent := range response.Choices[0].LogProbs.Content {
|
||||
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
|
||||
if len(logprobContent.Bytes) > 0 {
|
||||
foundAtLeastBytes = logprobContent.Bytes
|
||||
}
|
||||
if len(logprobContent.Token) > 0 {
|
||||
foundatLeastToken = logprobContent.Token
|
||||
}
|
||||
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
|
||||
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
|
||||
|
||||
// If top_logprobs is requested, verify top_logprobs array respects the limit
|
||||
if len(logprobContent.TopLogProbs) > 0 {
|
||||
// Should respect top_logprobs limit (3 in this test)
|
||||
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
|
||||
for _, topLogprob := range logprobContent.TopLogProbs {
|
||||
if len(topLogprob.Bytes) > 0 {
|
||||
foundAtLeastTopLogprobBytes = topLogprob.Bytes
|
||||
}
|
||||
if len(topLogprob.Token) > 0 {
|
||||
foundatLeastTopLogprob = topLogprob.Token
|
||||
}
|
||||
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Expect(foundAtLeastBytes).ToNot(BeEmpty())
|
||||
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
|
||||
Expect(foundatLeastToken).ToNot(BeEmpty())
|
||||
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("applies logit_bias to chat completions when requested", func() {
|
||||
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
|
||||
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
|
||||
logitBias := map[string]int{
|
||||
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
|
||||
}
|
||||
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
|
||||
Model: "testmodel.ggml",
|
||||
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
|
||||
LogitBias: logitBias,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(response.Choices)).To(Equal(1))
|
||||
Expect(response.Choices[0].Message).ToNot(BeNil())
|
||||
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
// If logit_bias is applied, the response should be generated successfully
|
||||
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
|
||||
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -1006,7 +1103,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -1022,7 +1123,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -15,17 +17,17 @@ import (
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
if !ok || input.ModelID == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
|
||||
@@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -17,19 +18,19 @@ import (
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/text-to-speech/{voice-id} [post]
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
voiceID := c.Params("voice-id")
|
||||
voiceID := c.Param("voice-id")
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
|
||||
if !ok || input.ModelID == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
|
||||
@@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,28 +2,32 @@ package explorer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
)
|
||||
|
||||
func Dashboard() func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
func Dashboard() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
}
|
||||
|
||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
accept := c.Request().Header.Get("Accept")
|
||||
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
|
||||
// The client expects a JSON response
|
||||
return c.Status(fiber.StatusOK).JSON(summary)
|
||||
return c.JSON(http.StatusOK, summary)
|
||||
} else {
|
||||
// Render index
|
||||
return c.Render("views/explorer", summary)
|
||||
return c.Render(http.StatusOK, "views/explorer", summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -39,8 +43,8 @@ type Network struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
results := []Network{}
|
||||
for _, token := range db.TokenList() {
|
||||
networkData, exists := db.Get(token) // get the token data
|
||||
@@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
|
||||
return len(results[i].Clusters) > len(results[j].Clusters)
|
||||
})
|
||||
|
||||
return c.JSON(results)
|
||||
return c.JSON(http.StatusOK, results)
|
||||
}
|
||||
}
|
||||
|
||||
func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
request := new(AddNetworkRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
if err := c.Bind(request); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
if request.Token == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
|
||||
}
|
||||
|
||||
if request.Name == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
|
||||
}
|
||||
|
||||
if request.Description == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
|
||||
}
|
||||
|
||||
// TODO: check if token is valid, otherwise reject
|
||||
// try to decode the token from base64
|
||||
_, err := base64.StdEncoding.DecodeString(request.Token)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
|
||||
}
|
||||
|
||||
if _, exists := db.Get(request.Token); exists {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
|
||||
}
|
||||
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package jina
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -17,24 +18,36 @@ import (
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
|
||||
|
||||
var requestTopN int32
|
||||
docs := int32(len(input.Documents))
|
||||
if input.TopN == nil { // omit top_n to get all
|
||||
requestTopN = docs
|
||||
} else {
|
||||
requestTopN = int32(*input.TopN)
|
||||
if requestTopN < 1 {
|
||||
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
|
||||
}
|
||||
if requestTopN > docs { // make it more obvious for backends
|
||||
requestTopN = docs
|
||||
}
|
||||
}
|
||||
request := &proto.RerankRequest{
|
||||
Query: input.Query,
|
||||
TopN: int32(input.TopN),
|
||||
TopN: requestTopN,
|
||||
Documents: input.Documents,
|
||||
}
|
||||
|
||||
@@ -58,6 +71,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
|
||||
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /backends/jobs/{uuid} [get]
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.backendApplier.GetStatus(c.Params("uuid"))
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
status := mgs.backendApplier.GetStatus(c.Param("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
return c.JSON(200, status)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /backends/jobs [get]
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.backendApplier.GetAllStatus())
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, mgs.backendApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/apply [post]
|
||||
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(GalleryBackend)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
|
||||
Galleries: mgs.galleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,9 +91,9 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
|
||||
// @Param name path string true "Backend name"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/delete/{name} [post]
|
||||
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backendName := c.Params("name")
|
||||
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backendName := c.Param("name")
|
||||
|
||||
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
|
||||
Delete: true,
|
||||
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
// @Summary List all Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(backends.GetAll())
|
||||
return c.JSON(200, backends.GetAll())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /backends/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
return c.Blob(200, "application/json", dat)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
|
||||
// @Summary List all available Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(backends)
|
||||
return c.JSON(200, backends)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,45 +1,45 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// BackendMonitorEndpoint returns the status of the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Success 200 {object} proto.StatusResponse "Response"
|
||||
// @Router /backend/monitor [get]
|
||||
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := bm.CheckAndSample(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// BackendShutdownEndpoint shuts down the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Router /backend/shutdown [post]
|
||||
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bm.ShutdownModel(input.Model)
|
||||
}
|
||||
}
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// BackendMonitorEndpoint returns the status of the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Success 200 {object} proto.StatusResponse "Response"
|
||||
// @Router /backend/monitor [get]
|
||||
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := bm.CheckAndSample(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// BackendShutdownEndpoint shuts down the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Router /backend/shutdown [post]
|
||||
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bm.ShutdownModel(input.Model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -16,17 +16,17 @@ import (
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
|
||||
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@ package localai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
@@ -14,15 +16,15 @@ import (
|
||||
)
|
||||
|
||||
// GetEditModelPage renders the edit model page with current configuration
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
@@ -31,7 +33,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
modelConfigFile := modelConfig.GetModelConfigFile()
|
||||
@@ -40,7 +42,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Model configuration file not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
configData, err := os.ReadFile(modelConfigFile)
|
||||
if err != nil {
|
||||
@@ -48,7 +50,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Failed to read configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Render the edit page with the current configuration
|
||||
@@ -69,20 +71,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Version: internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
return c.Render("views/model-editor", templateData)
|
||||
return c.Render(http.StatusOK, "views/model-editor", templateData)
|
||||
}
|
||||
}
|
||||
|
||||
// EditModelEndpoint handles updating existing model configurations
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
@@ -91,17 +93,24 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Existing model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read request body: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Check content to see if it's a valid model config
|
||||
@@ -113,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
@@ -122,7 +131,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Validate the configuration
|
||||
@@ -132,7 +141,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Error: "Validation failed",
|
||||
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Load the existing configuration
|
||||
@@ -142,7 +151,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Model configuration not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Write new content to file
|
||||
@@ -151,7 +160,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Reload configurations
|
||||
@@ -160,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
@@ -169,7 +178,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
@@ -179,20 +188,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Filename: configPath,
|
||||
Config: req,
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadModelsEndpoint handles reloading model configurations from disk
|
||||
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the models
|
||||
@@ -201,7 +210,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
|
||||
Success: false,
|
||||
Error: "Failed to preload models: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
@@ -209,6 +218,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
|
||||
Success: true,
|
||||
Message: "Model configurations reloaded successfully",
|
||||
}
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -15,6 +17,14 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// testRenderer is a simple renderer for tests that returns JSON
|
||||
type testRenderer struct{}
|
||||
|
||||
func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
|
||||
// For tests, just return the data as JSON
|
||||
return json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
var _ = Describe("Edit Model test", func() {
|
||||
|
||||
var tempDir string
|
||||
@@ -40,33 +50,35 @@ var _ = Describe("Edit Model test", func() {
|
||||
//modelLoader := model.NewModelLoader(systemState, true)
|
||||
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
|
||||
|
||||
// Define Fiber app.
|
||||
app := fiber.New()
|
||||
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
// Define Echo app and register all routes upfront
|
||||
app := echo.New()
|
||||
// Set up a simple renderer for the test
|
||||
app.Renderer = &testRenderer{}
|
||||
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
|
||||
|
||||
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/import-model", requestBody)
|
||||
resp, err := app.Test(req, 5000)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req := httptest.NewRequest("POST", "/import-model", requestBody)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(rec.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
|
||||
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
|
||||
req = httptest.NewRequest("GET", "/edit-model/foo", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
|
||||
resp, _ = app.Test(req, 1)
|
||||
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
body, err = io.ReadAll(rec.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
|
||||
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
|
||||
// The response contains the model configuration with backend field
|
||||
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
|
||||
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,160 +1,160 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendGalleries []config.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs/{uuid} [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||
// @Summary Install models to LocalAI.
|
||||
// @Param request body GalleryModel true "query params"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/apply [post]
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Req: input.GalleryModel,
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
BackendGalleries: mgs.backendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||
// @Summary delete models to LocalAI.
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/delete/{name} [post]
|
||||
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Delete: true,
|
||||
GalleryElementName: modelName,
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
|
||||
|
||||
m := []gallery.Metadata{}
|
||||
|
||||
for _, mm := range models {
|
||||
m = append(m, mm.Metadata)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Models %#v", m)
|
||||
|
||||
dat, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal models: %w", err)
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /models/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendGalleries []config.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs/{uuid} [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Param("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(200, status)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||
// @Summary Install models to LocalAI.
|
||||
// @Param request body GalleryModel true "query params"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/apply [post]
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Req: input.GalleryModel,
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
BackendGalleries: mgs.backendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||
// @Summary delete models to LocalAI.
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/delete/{name} [post]
|
||||
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Delete: true,
|
||||
GalleryElementName: modelName,
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
|
||||
|
||||
m := []gallery.Metadata{}
|
||||
|
||||
for _, mm := range models {
|
||||
m = append(m, mm.Metadata)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Models %#v", m)
|
||||
|
||||
dat, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal models: %w", err)
|
||||
}
|
||||
return c.Blob(200, "application/json", dat)
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /models/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Blob(200, "application/json", dat)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -21,17 +21,17 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/tokenMetrics [get]
|
||||
// @Router /tokenMetrics [get]
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input := new(schema.TokenMetricsRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || modelFile != "" {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,18 @@ package localai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -21,12 +23,12 @@ import (
|
||||
)
|
||||
|
||||
// ImportModelURIEndpoint handles creating new model configurations from a URI
|
||||
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input := new(schema.ImportModelRequest)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -61,7 +63,7 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{
|
||||
return c.JSON(200, schema.GalleryResponse{
|
||||
ID: uuid.String(),
|
||||
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
|
||||
})
|
||||
@@ -69,22 +71,28 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
}
|
||||
|
||||
// ImportModelEndpoint handles creating new model configurations
|
||||
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read request body: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Check content type to determine how to parse
|
||||
contentType := string(c.Context().Request.Header.ContentType())
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
var modelConfig config.ModelConfig
|
||||
var err error
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// Parse JSON
|
||||
@@ -93,7 +101,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
||||
// Parse YAML
|
||||
@@ -102,18 +110,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else {
|
||||
// Try to auto-detect format
|
||||
if strings.TrimSpace(string(body))[0] == '{' {
|
||||
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
|
||||
// Looks like JSON
|
||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else {
|
||||
// Assume YAML
|
||||
@@ -122,7 +130,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,7 +141,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
@@ -145,7 +153,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Invalid configuration",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Create the configuration file
|
||||
@@ -155,7 +163,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Model path not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Marshal to YAML for storage
|
||||
@@ -165,7 +173,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Write the file
|
||||
@@ -174,7 +182,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
@@ -182,7 +190,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
@@ -191,7 +199,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
// Return success response
|
||||
response := ModelResponse{
|
||||
@@ -199,6 +207,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Message: "Model configuration created successfully",
|
||||
Filename: filepath.Base(configPath),
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
323
core/http/endpoints/localai/mcp.go
Normal file
323
core/http/endpoints/localai/mcp.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MCP SSE Event Types
|
||||
type MCPReasoningEvent struct {
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type MCPToolCallEvent struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
|
||||
type MCPToolResultEvent struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
type MCPStatusEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type MCPAssistantEvent struct {
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type MCPErrorEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
|
||||
// @Summary Stream MCP chat completions with reasoning, tool calls, and results
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/mcp/chat/completions [post]
|
||||
func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Request().Header.Get("X-Correlation-ID")
|
||||
if id == "" {
|
||||
id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
|
||||
// Get MCP config from model config
|
||||
remote, stdio, err := config.MCP.MCPConfigFromYAML()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get MCP config: %w", err)
|
||||
}
|
||||
|
||||
// Check if we have tools in cache, or we have to have an initial connection
|
||||
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get MCP sessions: %w", err)
|
||||
}
|
||||
|
||||
if len(sessions) == 0 {
|
||||
return fmt.Errorf("no working MCP servers found")
|
||||
}
|
||||
|
||||
// Build fragment from messages
|
||||
fragment := cogito.NewEmptyFragment()
|
||||
for _, message := range input.Messages {
|
||||
fragment = fragment.AddMessage(message.Role, message.StringContent)
|
||||
}
|
||||
|
||||
port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
|
||||
apiKey := ""
|
||||
if len(appConfig.ApiKeys) > 0 {
|
||||
apiKey = appConfig.ApiKeys[0]
|
||||
}
|
||||
|
||||
ctxWithCancellation, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// TODO: instead of connecting to the API, we should just wire this internally
|
||||
// and act like completion.go.
|
||||
// We can do this as cogito expects an interface and we can create one that
|
||||
// we satisfy to just call internally ComputeChoices
|
||||
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
|
||||
|
||||
// Build cogito options using the consolidated method
|
||||
cogitoOpts := config.BuildCogitoOptions()
|
||||
cogitoOpts = append(
|
||||
cogitoOpts,
|
||||
cogito.WithContext(ctxWithCancellation),
|
||||
cogito.WithMCPs(sessions...),
|
||||
)
|
||||
// Check if streaming is requested
|
||||
toStream := input.Stream
|
||||
|
||||
if !toStream {
|
||||
// Non-streaming mode: execute synchronously and return JSON response
|
||||
cogitoOpts = append(
|
||||
cogitoOpts,
|
||||
cogito.WithStatusCallback(func(s string) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithReasoningCallback(func(s string) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
|
||||
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
|
||||
return true
|
||||
}),
|
||||
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
|
||||
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
|
||||
}),
|
||||
)
|
||||
|
||||
f, err := cogito.ExecuteTools(
|
||||
defaultLLM, fragment,
|
||||
cogitoOpts...,
|
||||
)
|
||||
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = defaultLLM.Ask(ctxWithCancellation, f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
|
||||
Object: "chat.completion",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
// Streaming mode: use SSE
|
||||
// Set up SSE headers
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
|
||||
// Create channel for streaming events
|
||||
events := make(chan interface{})
|
||||
ended := make(chan error, 1)
|
||||
|
||||
// Set up callbacks for streaming
|
||||
statusCallback := func(s string) {
|
||||
events <- MCPStatusEvent{
|
||||
Type: "status",
|
||||
Message: s,
|
||||
}
|
||||
}
|
||||
|
||||
reasoningCallback := func(s string) {
|
||||
events <- MCPReasoningEvent{
|
||||
Type: "reasoning",
|
||||
Content: s,
|
||||
}
|
||||
}
|
||||
|
||||
toolCallCallback := func(t *cogito.ToolChoice) bool {
|
||||
events <- MCPToolCallEvent{
|
||||
Type: "tool_call",
|
||||
Name: t.Name,
|
||||
Arguments: t.Arguments,
|
||||
Reasoning: t.Reasoning,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
toolCallResultCallback := func(t cogito.ToolStatus) {
|
||||
events <- MCPToolResultEvent{
|
||||
Type: "tool_result",
|
||||
Name: t.Name,
|
||||
Result: t.Result,
|
||||
}
|
||||
}
|
||||
|
||||
cogitoOpts = append(cogitoOpts,
|
||||
cogito.WithStatusCallback(statusCallback),
|
||||
cogito.WithReasoningCallback(reasoningCallback),
|
||||
cogito.WithToolCallBack(toolCallCallback),
|
||||
cogito.WithToolCallResultCallback(toolCallResultCallback),
|
||||
)
|
||||
|
||||
// Execute tools in a goroutine
|
||||
go func() {
|
||||
defer close(events)
|
||||
|
||||
f, err := cogito.ExecuteTools(
|
||||
defaultLLM, fragment,
|
||||
cogitoOpts...,
|
||||
)
|
||||
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
|
||||
events <- MCPErrorEvent{
|
||||
Type: "error",
|
||||
Message: fmt.Sprintf("Failed to execute tools: %v", err),
|
||||
}
|
||||
ended <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get final response
|
||||
f, err = defaultLLM.Ask(ctxWithCancellation, f)
|
||||
if err != nil {
|
||||
events <- MCPErrorEvent{
|
||||
Type: "error",
|
||||
Message: fmt.Sprintf("Failed to get response: %v", err),
|
||||
}
|
||||
ended <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Stream final assistant response
|
||||
content := f.LastMessage().Content
|
||||
events <- MCPAssistantEvent{
|
||||
Type: "assistant",
|
||||
Content: content,
|
||||
}
|
||||
|
||||
ended <- nil
|
||||
}()
|
||||
|
||||
// Stream events to client
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context was cancelled (client disconnected or request cancelled)
|
||||
log.Debug().Msgf("Request context cancelled, stopping stream")
|
||||
cancel()
|
||||
break LOOP
|
||||
case event := <-events:
|
||||
if event == nil {
|
||||
// Channel closed
|
||||
break LOOP
|
||||
}
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal event: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Debug().Msgf("Sending event: %s", string(eventData))
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending event failed: %v", err)
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
// Send done signal
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
errorEvent := MCPErrorEvent{
|
||||
Type: "error",
|
||||
Message: err.Error(),
|
||||
}
|
||||
errorData, marshalErr := json.Marshal(errorEvent)
|
||||
if marshalErr != nil {
|
||||
fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
|
||||
}
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Stream ended")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,46 +1,47 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
||||
// @Summary Prometheus metrics endpoint
|
||||
// @Param request body config.Gallery true "Gallery details"
|
||||
// @Router /metrics [get]
|
||||
func LocalAIMetricsEndpoint() fiber.Handler {
|
||||
return adaptor.HTTPHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c *fiber.Ctx) bool
|
||||
metricsService *services.LocalAIMetricsService
|
||||
}
|
||||
|
||||
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metricsService: metrics,
|
||||
Filter: func(c *fiber.Ctx) bool {
|
||||
return c.Path() == "/metrics"
|
||||
},
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return c.Next()
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Method()
|
||||
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
||||
package localai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
||||
// @Summary Prometheus metrics endpoint
|
||||
// @Param request body config.Gallery true "Gallery details"
|
||||
// @Router /metrics [get]
|
||||
func LocalAIMetricsEndpoint() echo.HandlerFunc {
|
||||
return echo.WrapHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c echo.Context) bool
|
||||
metricsService *services.LocalAIMetricsService
|
||||
}
|
||||
|
||||
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metricsService: metrics,
|
||||
Filter: func(c echo.Context) bool {
|
||||
return c.Path() == "/metrics"
|
||||
},
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return next(c)
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Request().Method
|
||||
|
||||
start := time.Now()
|
||||
err := next(c)
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
// @Summary Returns available P2P nodes
|
||||
// @Success 200 {object} []schema.P2PNodesResponse "Response"
|
||||
// @Router /api/p2p [get]
|
||||
func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// Render index
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(schema.P2PNodesResponse{
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, schema.P2PNodesResponse{
|
||||
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
|
||||
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
|
||||
})
|
||||
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
// @Summary Show the P2P token
|
||||
// @Success 200 {string} string "Response"
|
||||
// @Router /api/p2p/token [get]
|
||||
func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
|
||||
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresSet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
vals[i] = []byte(v)
|
||||
}
|
||||
|
||||
err = store.SetCols(c.Context(), sb, input.Keys, vals)
|
||||
err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Send(nil)
|
||||
return c.NoContent(200)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresDelete)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
|
||||
}
|
||||
defer sl.Close()
|
||||
|
||||
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
|
||||
if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Send(nil)
|
||||
return c.NoContent(200)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresGet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
}
|
||||
defer sl.Close()
|
||||
|
||||
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
|
||||
keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
return c.JSON(200, res)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresFind)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
|
||||
}
|
||||
defer sl.Close()
|
||||
|
||||
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
|
||||
keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
return c.JSON(200, res)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
// @Summary Show the LocalAI instance information
|
||||
// @Success 200 {object} schema.SystemInformationResponse "Response"
|
||||
// @Router /system [get]
|
||||
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
availableBackends := []string{}
|
||||
loadedModels := ml.ListLoadedModels()
|
||||
for b := range appConfig.ExternalGRPCBackends {
|
||||
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
|
||||
for _, m := range loadedModels {
|
||||
sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
|
||||
}
|
||||
return c.JSON(
|
||||
return c.JSON(200,
|
||||
schema.SystemInformationResponse{
|
||||
Backends: availableBackends,
|
||||
Models: sysmodels,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -14,22 +14,22 @@ import (
|
||||
// @Param request body schema.TokenizeRequest true "Request"
|
||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.JSON(tokenResponse)
|
||||
return c.JSON(200, tokenResponse)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
@@ -22,16 +24,16 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/audio/speech [post]
|
||||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
|
||||
@@ -59,6 +61,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Download(filePath)
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -16,26 +16,26 @@ import (
|
||||
// @Param request body schema.VADRequest true "query params"
|
||||
// @Success 200 {object} proto.VADResponse "Response"
|
||||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")
|
||||
|
||||
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
|
||||
resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,19 +7,20 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -64,18 +65,18 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.VideoRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /video [post]
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
if !ok || input.Model == "" {
|
||||
log.Error().Msg("Video Endpoint - Invalid Input")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Video Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
src := ""
|
||||
@@ -164,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
baseURL := middleware.BaseURL(c)
|
||||
|
||||
fn, err := backend.VideoGeneration(
|
||||
height,
|
||||
@@ -201,7 +202,10 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-videos/" + base
|
||||
item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
@@ -216,6 +220,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
@@ -40,10 +42,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
// Get model statuses to display in the UI the operation in progress
|
||||
processingModels, taskTypes := opcache.GetStatus()
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Models": modelsWithoutConfig,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
@@ -54,12 +56,21 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
"InstalledBackends": installedBackends,
|
||||
}
|
||||
|
||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
accept := c.Request().Header.Get("Accept")
|
||||
// Default to HTML if Accept header is empty (browser behavior)
|
||||
// Only return JSON if explicitly requested or Content-Type is application/json
|
||||
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
|
||||
// The client expects a JSON response
|
||||
return c.Status(fiber.StatusOK).JSON(summary)
|
||||
return c.JSON(200, summary)
|
||||
} else {
|
||||
// Render index
|
||||
return c.Render("views/index", summary)
|
||||
// Check if this is the manage route
|
||||
templateName := "views/index"
|
||||
if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" {
|
||||
templateName = "views/manage"
|
||||
}
|
||||
// Render appropriate template
|
||||
return c.Render(200, templateName, summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -20,68 +17,14 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
|
||||
// Fasthttp doesn't support context cancellation from the caller
|
||||
// for non-streaming requests, so we need to monitor the connection directly.
|
||||
// Monitor connection for client disconnection during non-streaming requests
|
||||
// We access the connection directly via c.Context().Conn() to monitor it
|
||||
// during ComputeChoices execution, not after the response is sent
|
||||
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
|
||||
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
|
||||
var conn net.Conn = c.Context().Conn()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
// Clear read deadline when goroutine exits
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1)
|
||||
// Use a short read deadline to periodically check if connection is closed
|
||||
// Without a deadline, Read() would block indefinitely waiting for data
|
||||
// that will never come (client is waiting for response, not sending more data)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
// Request completed or was cancelled - exit goroutine
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Set a short deadline - if connection is closed, read will fail immediately
|
||||
// If connection is open but no data, it will timeout and we check again
|
||||
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
||||
_, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
// Check if it's a timeout (connection still open, just no data)
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Timeout is expected - connection is still open, just no data to read
|
||||
// Continue the loop to check again
|
||||
continue
|
||||
}
|
||||
// Connection closed or other error - cancel the context to stop gRPC call
|
||||
log.Debug().Msgf("Calling cancellation function")
|
||||
cancelFunc()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
|
||||
// @Summary Generate a chat completions for a given prompt and model.
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
|
||||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
@@ -235,21 +178,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
textContentToReturn = ""
|
||||
id = uuid.New().String()
|
||||
created = int(time.Now().Unix())
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
|
||||
@@ -392,13 +335,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
case toStream:
|
||||
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
c.Set("X-Correlation-ID", id)
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
ended := make(chan error, 1)
|
||||
@@ -411,103 +351,101 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}()
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
usage := &schema.OpenAIUsage{}
|
||||
toolsCalled := false
|
||||
usage := &schema.OpenAIUsage{}
|
||||
toolsCalled := false
|
||||
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-input.Context.Done():
|
||||
// Context was cancelled (client disconnected or request cancelled)
|
||||
log.Debug().Msgf("Request context cancelled, stopping stream")
|
||||
input.Cancel()
|
||||
break LOOP
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
input.Cancel()
|
||||
continue
|
||||
}
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
_, err = fmt.Fprintf(w, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
}
|
||||
w.Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &stopReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
w.WriteString("data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
}
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
|
||||
return
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-input.Context.Done():
|
||||
// Context was cancelled (client disconnected or request cancelled)
|
||||
log.Debug().Msgf("Request context cancelled, stopping stream")
|
||||
input.Cancel()
|
||||
break LOOP
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
input.Cancel()
|
||||
continue
|
||||
}
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &stopReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
}
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
finishReason = FinishReasonToolCalls
|
||||
} else if toolsCalled {
|
||||
finishReason = FinishReasonFunctionCall
|
||||
}
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
finishReason = FinishReasonToolCalls
|
||||
} else if toolsCalled {
|
||||
finishReason = FinishReasonFunctionCall
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
log.Debug().Msgf("Stream ended")
|
||||
}))
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
log.Debug().Msgf("Stream ended")
|
||||
return nil
|
||||
|
||||
// no streaming mode
|
||||
@@ -589,9 +527,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
}
|
||||
|
||||
// NOTE: this is a workaround as fasthttp
|
||||
// context cancellation does not fire in non-streaming requests
|
||||
handleConnectionCancellation(c, input.Cancel, input.Context)
|
||||
// Echo properly supports context cancellation via c.Request().Context()
|
||||
// No workaround needed!
|
||||
|
||||
result, tokenUsage, err := ComputeChoices(
|
||||
input,
|
||||
@@ -628,7 +565,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -698,7 +635,32 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in
|
||||
}
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON)
|
||||
// Extract logprobs from request
|
||||
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
|
||||
var logprobs *int
|
||||
var topLogprobs *int
|
||||
if input.Logprobs.IsEnabled() {
|
||||
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
|
||||
if input.TopLogprobs != nil {
|
||||
topLogprobs = input.TopLogprobs
|
||||
// For backend compatibility, set logprobs to the top_logprobs value
|
||||
logprobs = input.TopLogprobs
|
||||
} else {
|
||||
// Default to 1 if logprobs is true but top_logprobs not specified
|
||||
val := 1
|
||||
logprobs = &val
|
||||
topLogprobs = &val
|
||||
}
|
||||
}
|
||||
|
||||
// Extract logit_bias from request
|
||||
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
|
||||
var logitBias map[string]float64
|
||||
if len(input.LogitBias) > 0 {
|
||||
logitBias = input.LogitBias
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("model inference failed")
|
||||
return "", err
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
|
||||
@@ -26,7 +24,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -64,22 +62,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
return err
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
id := c.Request().Header.Get("X-Correlation-ID")
|
||||
if id == "" {
|
||||
id = uuid.New().String()
|
||||
}
|
||||
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
if config.ResponseFormatMap != nil {
|
||||
@@ -97,15 +98,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
@@ -130,78 +126,78 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
||||
}()
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(respData))
|
||||
w.Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
errorResp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
Text: "Internal error: " + err.Error(),
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
errorData, marshalErr := json.Marshal(errorResp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
fmt.Fprintf(w, "data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(errorData))
|
||||
}
|
||||
w.Flush()
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
stopReason := FinishReasonStop
|
||||
errorResp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
Text: "Internal error: " + err.Error(),
|
||||
},
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
Object: "text_completion",
|
||||
}
|
||||
errorData, marshalErr := json.Marshal(errorResp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
|
||||
}
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
}
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return <-ended
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
@@ -257,6 +253,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
@@ -23,20 +23,20 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
|
||||
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -21,16 +21,16 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -14,13 +15,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -65,18 +66,18 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
log.Error().Msg("Image Endpoint - Invalid Input")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Image Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// Process input images (for img2img/inpainting)
|
||||
@@ -188,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
baseURL := middleware.BaseURL(c)
|
||||
|
||||
// Use the first input image as src if available, otherwise use the original src
|
||||
inputSrc := src
|
||||
@@ -215,7 +216,10 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-images/" + base
|
||||
item.URL, err = url.JoinPath(baseURL, "generated-images", base)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
@@ -234,7 +238,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -55,9 +55,34 @@ func ComputeChoices(
|
||||
}
|
||||
}
|
||||
|
||||
// Extract logprobs from request
|
||||
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
|
||||
var logprobs *int
|
||||
var topLogprobs *int
|
||||
if req.Logprobs.IsEnabled() {
|
||||
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
|
||||
if req.TopLogprobs != nil {
|
||||
topLogprobs = req.TopLogprobs
|
||||
// For backend compatibility, set logprobs to the top_logprobs value
|
||||
logprobs = req.TopLogprobs
|
||||
} else {
|
||||
// Default to 1 if logprobs is true but top_logprobs not specified
|
||||
val := 1
|
||||
logprobs = &val
|
||||
topLogprobs = &val
|
||||
}
|
||||
}
|
||||
|
||||
// Extract logit_bias from request
|
||||
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
|
||||
var logitBias map[string]float64
|
||||
if len(req.LogitBias) > 0 {
|
||||
logitBias = req.LogitBias
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(
|
||||
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON)
|
||||
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
@@ -78,6 +103,11 @@ func ComputeChoices(
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Add logprobs to the last choice if present
|
||||
if prediction.Logprobs != nil && len(result) > 0 {
|
||||
result[len(result)-1].Logprobs = prediction.Logprobs
|
||||
}
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
@@ -12,14 +12,15 @@ import (
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
filter := c.QueryParam("filter")
|
||||
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
var policy services.LooseFilePolicy
|
||||
if c.QueryBool("excludeConfigured", true) {
|
||||
excludeConfigured := c.QueryParam("excludeConfigured")
|
||||
if excludeConfigured == "" || excludeConfigured == "true" {
|
||||
policy = services.SKIP_IF_CONFIGURED
|
||||
} else {
|
||||
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
|
||||
@@ -41,7 +42,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
|
||||
return c.JSON(schema.ModelsDataResponse{
|
||||
return c.JSON(200, schema.ModelsDataResponse{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -26,24 +26,27 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /mcp/v1/completions [post]
|
||||
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// We do not support streaming mode (Yet?)
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
ctx := c.Context()
|
||||
ctx := c.Request().Context()
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
id := c.Request().Header.Get("X-Correlation-ID")
|
||||
if id == "" {
|
||||
id = uuid.New().String()
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
|
||||
@@ -80,47 +83,34 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
|
||||
ctxWithCancellation, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
handleConnectionCancellation(c, cancel, ctxWithCancellation)
|
||||
|
||||
// TODO: instead of connecting to the API, we should just wire this internally
|
||||
// and act like completion.go.
|
||||
// We can do this as cogito expects an interface and we can create one that
|
||||
// we satisfy to just call internally ComputeChoices
|
||||
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
|
||||
|
||||
cogitoOpts := []cogito.Option{
|
||||
// Build cogito options using the consolidated method
|
||||
cogitoOpts := config.BuildCogitoOptions()
|
||||
|
||||
cogitoOpts = append(
|
||||
cogitoOpts,
|
||||
cogito.WithContext(ctxWithCancellation),
|
||||
cogito.WithMCPs(sessions...),
|
||||
cogito.WithStatusCallback(func(s string) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithContext(ctxWithCancellation),
|
||||
cogito.WithMCPs(sessions...),
|
||||
cogito.WithIterations(3), // default to 3 iterations
|
||||
cogito.WithMaxAttempts(3), // default to 3 attempts
|
||||
cogito.WithForceReasoning(),
|
||||
}
|
||||
|
||||
if config.Agent.EnableReasoning {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
|
||||
}
|
||||
|
||||
if config.Agent.EnablePlanning {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
|
||||
}
|
||||
|
||||
if config.Agent.EnableMCPPrompts {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
|
||||
}
|
||||
|
||||
if config.Agent.EnablePlanReEvaluator {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
|
||||
}
|
||||
|
||||
if config.Agent.MaxIterations != 0 {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithIterations(config.Agent.MaxIterations))
|
||||
}
|
||||
|
||||
if config.Agent.MaxAttempts != 0 {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(config.Agent.MaxAttempts))
|
||||
}
|
||||
cogito.WithReasoningCallback(func(s string) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
|
||||
return true
|
||||
}),
|
||||
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
|
||||
}),
|
||||
)
|
||||
|
||||
f, err := cogito.ExecuteTools(
|
||||
defaultLLM, fragment,
|
||||
@@ -147,6 +137,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
@@ -167,32 +169,50 @@ type Model interface {
|
||||
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins
|
||||
},
|
||||
}
|
||||
|
||||
// TODO: Implement ephemeral keys to allow these endpoints to be used
|
||||
func RealtimeSessions(application *application.Application) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.SendStatus(501)
|
||||
func RealtimeSessions(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.NoContent(501)
|
||||
}
|
||||
}
|
||||
|
||||
func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.SendStatus(501)
|
||||
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.NoContent(501)
|
||||
}
|
||||
}
|
||||
|
||||
func Realtime(application *application.Application) fiber.Handler {
|
||||
return websocket.New(registerRealtime(application))
|
||||
func Realtime(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Extract query parameters from Echo context before passing to websocket handler
|
||||
model := c.QueryParam("model")
|
||||
if model == "" {
|
||||
model = "gpt-4o"
|
||||
}
|
||||
intent := c.QueryParam("intent")
|
||||
|
||||
registerRealtime(application, model, intent)(ws)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
|
||||
func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
|
||||
return func(c *websocket.Conn) {
|
||||
|
||||
evaluator := application.TemplatesEvaluator()
|
||||
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
|
||||
|
||||
model := c.Query("model", "gpt-4o")
|
||||
|
||||
intent := c.Query("intent")
|
||||
if intent != "transcription" {
|
||||
sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
|
||||
}
|
||||
@@ -1067,7 +1087,7 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
|
||||
// For example, the model might return a special token or JSON indicating a function call
|
||||
|
||||
/*
|
||||
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
|
||||
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil)
|
||||
|
||||
result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
if !shouldUseFn {
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -24,19 +24,19 @@ import (
|
||||
// @Param file formData file true "file"
|
||||
// @Success 200 {object} map[string]string "Response"
|
||||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
diarize := c.FormValue("diarize", "false") != "false"
|
||||
diarize := c.FormValue("diarize") != "false"
|
||||
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr)
|
||||
return c.JSON(http.StatusOK, tr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -14,20 +14,24 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
if body := c.Body(); len(body) > 0 {
|
||||
body := make([]byte, 0)
|
||||
if c.Request().Body != nil {
|
||||
c.Request().Body.Read(body)
|
||||
}
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
}
|
||||
// Build VideoRequest using shared mapper
|
||||
vr := MapOpenAIToVideo(input, raw)
|
||||
// Place VideoRequest into locals so localai.VideoEndpoint can consume it
|
||||
c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
|
||||
// Place VideoRequest into context so localai.VideoEndpoint can consume it
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
|
||||
// Delegate to existing localai handler
|
||||
return localai.VideoEndpoint(cl, ml, appConfig)(c)
|
||||
}
|
||||
|
||||
@@ -1,48 +1,50 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/favicon"
|
||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Explorer(db *explorer.Database) *fiber.App {
|
||||
func Explorer(db *explorer.Database) *echo.Echo {
|
||||
e := echo.New()
|
||||
|
||||
fiberCfg := fiber.Config{
|
||||
Views: renderEngine(),
|
||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
|
||||
DisableStartupMessage: false,
|
||||
// Override default error handler
|
||||
// Set renderer
|
||||
e.Renderer = renderEngine()
|
||||
|
||||
// Hide banner
|
||||
e.HideBanner = true
|
||||
|
||||
e.Pre(middleware.StripPathPrefix())
|
||||
routes.RegisterExplorerRoutes(e, db)
|
||||
|
||||
// Favicon handler
|
||||
e.GET("/favicon.svg", func(c echo.Context) error {
|
||||
data, err := embedDirStatic.ReadFile("static/favicon.svg")
|
||||
if err != nil {
|
||||
return c.NoContent(http.StatusNotFound)
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "image/svg+xml")
|
||||
return c.Blob(http.StatusOK, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
// Static files - use fs.Sub to create a filesystem rooted at "static"
|
||||
staticFS, err := fs.Sub(embedDirStatic, "static")
|
||||
if err != nil {
|
||||
// Log error but continue - static files might not work
|
||||
log.Error().Err(err).Msg("failed to create static filesystem")
|
||||
} else {
|
||||
e.StaticFS("/static", staticFS)
|
||||
}
|
||||
|
||||
app := fiber.New(fiberCfg)
|
||||
|
||||
app.Use(middleware.StripPathPrefix())
|
||||
routes.RegisterExplorerRoutes(app, db)
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
|
||||
app.Use(favicon.New(favicon.Config{
|
||||
URL: "/favicon.svg",
|
||||
FileSystem: httpFS,
|
||||
File: "static/favicon.svg",
|
||||
}))
|
||||
|
||||
app.Use("/static", filesystem.New(filesystem.Config{
|
||||
Root: httpFS,
|
||||
PathPrefix: "static",
|
||||
Browse: true,
|
||||
}))
|
||||
|
||||
// Define a custom 404 handler
|
||||
// Note: keep this at the bottom!
|
||||
app.Use(notFoundHandler)
|
||||
e.GET("/*", notFoundHandler)
|
||||
|
||||
return app
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -3,50 +3,108 @@ package middleware
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
|
||||
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
|
||||
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
|
||||
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
|
||||
|
||||
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
|
||||
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
|
||||
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
|
||||
// Create validator function
|
||||
validator := getApiKeyValidationFunction(applicationConfig)
|
||||
|
||||
return &v2keyauth.Config{
|
||||
CustomKeyLookup: customLookup,
|
||||
Next: getApiKeyRequiredFilterFunction(applicationConfig),
|
||||
Validator: getApiKeyValidationFunction(applicationConfig),
|
||||
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
|
||||
AuthScheme: "Bearer",
|
||||
// Create error handler
|
||||
errorHandler := getApiKeyErrorHandler(applicationConfig)
|
||||
|
||||
// Create Next function (skip middleware for certain requests)
|
||||
skipper := getApiKeyRequiredFilterFunction(applicationConfig)
|
||||
|
||||
// Wrap it with our custom key lookup that checks multiple sources
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Skip if skipper says so
|
||||
if skipper != nil && skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Try to extract key from multiple sources
|
||||
key, err := extractKeyFromMultipleSources(c)
|
||||
if err != nil {
|
||||
return errorHandler(err, c)
|
||||
}
|
||||
|
||||
// Validate the key
|
||||
valid, err := validator(key, c)
|
||||
if err != nil || !valid {
|
||||
return errorHandler(ErrMissingOrMalformedAPIKey, c)
|
||||
}
|
||||
|
||||
// Store key in context for later use
|
||||
c.Set("api_key", key)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
|
||||
return func(ctx *fiber.Ctx, err error) error {
|
||||
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
|
||||
// extractKeyFromMultipleSources checks multiple sources for the API key
|
||||
// in order: Authorization header, x-api-key header, xi-api-key header, token cookie
|
||||
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
|
||||
// Check Authorization header first
|
||||
auth := c.Request().Header.Get("Authorization")
|
||||
if auth != "" {
|
||||
// Check for Bearer scheme
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer "), nil
|
||||
}
|
||||
// If no Bearer prefix, return as-is (for backward compatibility)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// Check x-api-key header
|
||||
if key := c.Request().Header.Get("x-api-key"); key != "" {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check xi-api-key header
|
||||
if key := c.Request().Header.Get("xi-api-key"); key != "" {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check token cookie
|
||||
cookie, err := c.Cookie("token")
|
||||
if err == nil && cookie != nil && cookie.Value != "" {
|
||||
return cookie.Value, nil
|
||||
}
|
||||
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
|
||||
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
|
||||
return func(err error, c echo.Context) error {
|
||||
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return ctx.Next() // if no keys are set up, any error we get here is not an error.
|
||||
return nil // if no keys are set up, any error we get here is not an error.
|
||||
}
|
||||
ctx.Set("WWW-Authenticate", "Bearer")
|
||||
c.Response().Header().Set("WWW-Authenticate", "Bearer")
|
||||
if applicationConfig.OpaqueErrors {
|
||||
return ctx.SendStatus(401)
|
||||
return c.NoContent(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Check if the request content type is JSON
|
||||
contentType := string(ctx.Context().Request.Header.ContentType())
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return ctx.Status(401).JSON(schema.ErrorResponse{
|
||||
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "An authentication key is required",
|
||||
Code: 401,
|
||||
@@ -55,50 +113,69 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
|
||||
})
|
||||
}
|
||||
|
||||
return ctx.Status(401).Render("views/login", fiber.Map{
|
||||
"BaseURL": utils.BaseURL(ctx),
|
||||
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
|
||||
"BaseURL": BaseURL(c),
|
||||
})
|
||||
}
|
||||
if applicationConfig.OpaqueErrors {
|
||||
return ctx.SendStatus(500)
|
||||
return c.NoContent(http.StatusInternalServerError)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
|
||||
|
||||
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
|
||||
if applicationConfig.UseSubtleKeyComparison {
|
||||
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
|
||||
return func(key string, c echo.Context) (bool, error) {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return true, nil // If no keys are setup, accept everything
|
||||
}
|
||||
for _, validKey := range applicationConfig.ApiKeys {
|
||||
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
|
||||
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, v2keyauth.ErrMissingOrMalformedAPIKey
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
}
|
||||
|
||||
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
|
||||
return func(key string, c echo.Context) (bool, error) {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return true, nil // If no keys are setup, accept everything
|
||||
}
|
||||
for _, validKey := range applicationConfig.ApiKeys {
|
||||
if apiKey == validKey {
|
||||
if key == validKey {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, v2keyauth.ErrMissingOrMalformedAPIKey
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
}
|
||||
|
||||
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
|
||||
if applicationConfig.DisableApiKeyRequirementForHttpGet {
|
||||
return func(c *fiber.Ctx) bool {
|
||||
if c.Method() != "GET" {
|
||||
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
|
||||
return func(c echo.Context) bool {
|
||||
path := c.Request().URL.Path
|
||||
|
||||
// Always skip authentication for static files
|
||||
if strings.HasPrefix(path, "/static/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Always skip authentication for generated content
|
||||
if strings.HasPrefix(path, "/generated-audio/") ||
|
||||
strings.HasPrefix(path, "/generated-images/") ||
|
||||
strings.HasPrefix(path, "/generated-videos/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip authentication for favicon
|
||||
if path == "/favicon.svg" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle GET request exemptions if enabled
|
||||
if applicationConfig.DisableApiKeyRequirementForHttpGet {
|
||||
if c.Request().Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
|
||||
@@ -106,8 +183,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
return func(c *fiber.Ctx) bool { return false }
|
||||
}
|
||||
|
||||
48
core/http/middleware/baseurl.go
Normal file
48
core/http/middleware/baseurl.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// BaseURL returns the base URL for the given HTTP request context.
|
||||
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
|
||||
// The returned URL is guaranteed to end with `/`.
|
||||
// The method should be used in conjunction with the StripPathPrefix middleware.
|
||||
func BaseURL(c echo.Context) string {
|
||||
path := c.Path()
|
||||
origPath := c.Request().URL.Path
|
||||
|
||||
// Check if StripPathPrefix middleware stored the original path
|
||||
if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
|
||||
origPath = storedPath
|
||||
}
|
||||
|
||||
// Check X-Forwarded-Proto for scheme
|
||||
scheme := "http"
|
||||
if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
} else if c.Request().TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Check X-Forwarded-Host for host
|
||||
host := c.Request().Host
|
||||
if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
|
||||
if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
|
||||
prefixLen := len(origPath) - len(path)
|
||||
if prefixLen > 0 && prefixLen <= len(origPath) {
|
||||
pathPrefix := origPath[:prefixLen]
|
||||
if !strings.HasSuffix(pathPrefix, "/") {
|
||||
pathPrefix += "/"
|
||||
}
|
||||
return scheme + "://" + host + pathPrefix
|
||||
}
|
||||
}
|
||||
|
||||
return scheme + "://" + host + "/"
|
||||
}
|
||||
58
core/http/middleware/baseurl_test.go
Normal file
58
core/http/middleware/baseurl_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BaseURL", func() {
|
||||
Context("without prefix", func() {
|
||||
It("should return base URL without prefix", func() {
|
||||
app := echo.New()
|
||||
actualURL := ""
|
||||
|
||||
// Register route - use the actual request path so routing works
|
||||
routePath := "/hello/world"
|
||||
app.GET(routePath, func(c echo.Context) error {
|
||||
actualURL = BaseURL(c)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualURL).To(Equal("http://example.com/"), "base URL")
|
||||
})
|
||||
})
|
||||
|
||||
Context("with prefix", func() {
|
||||
It("should return base URL with prefix", func() {
|
||||
app := echo.New()
|
||||
actualURL := ""
|
||||
|
||||
// Register route with the stripped path (after middleware removes prefix)
|
||||
routePath := "/hello/world"
|
||||
app.GET(routePath, func(c echo.Context) error {
|
||||
// Simulate what StripPathPrefix middleware does - store original path
|
||||
c.Set("_original_path", "/myprefix/hello/world")
|
||||
// Modify the request path to simulate prefix stripping
|
||||
c.Request().URL.Path = "/hello/world"
|
||||
actualURL = BaseURL(c)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Make request with stripped path (middleware would have already processed it)
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
|
||||
})
|
||||
})
|
||||
})
|
||||
13
core/http/middleware/middleware_suite_test.go
Normal file
13
core/http/middleware/middleware_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware test suite")
|
||||
}
|
||||
@@ -1,470 +1,482 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type correlationIDKeyType string
|
||||
|
||||
// CorrelationIDKey to track request across process boundary
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
modelConfigLoader: modelConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
|
||||
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
|
||||
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
|
||||
|
||||
// TODO: Refactor to not return error if unchanged
|
||||
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
|
||||
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && model != "" {
|
||||
return
|
||||
}
|
||||
model = ctx.Params("model")
|
||||
|
||||
if (model == "") && ctx.Query("model") != "" {
|
||||
model = ctx.Query("model")
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
|
||||
if bearer != "" {
|
||||
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || localModelName == "" {
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
||||
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
|
||||
}
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if localModelName != "" { // Don't overwrite existing values
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
||||
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
||||
// return errors.New("this endpoint requires at least one model to be installed")
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
||||
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
||||
// Otherwise, it's in its own method below for now
|
||||
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input := initializer()
|
||||
if input == nil {
|
||||
return fmt.Errorf("unable to initialize body")
|
||||
}
|
||||
if err := ctx.BodyParser(input); err != nil {
|
||||
return fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
||||
if input.ModelName(nil) == "" {
|
||||
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && localModelName != "" {
|
||||
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
|
||||
input.ModelName(&localModelName)
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
|
||||
} else if cfg.Model == "" && input.ModelName(nil) != "" {
|
||||
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
|
||||
cfg.Model = input.ModelName(nil)
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID
|
||||
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
|
||||
ctx.Set("X-Correlation-ID", correlationID)
|
||||
|
||||
//c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
// Use the application context as parent to ensure cancellation on app shutdown
|
||||
// We'll monitor the Fiber context separately and cancel our context when the request is canceled
|
||||
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
// Monitor the Fiber context and cancel our context when it's canceled
|
||||
// This ensures we respect request cancellation without causing panics
|
||||
go func(fiberCtx *fasthttp.RequestCtx) {
|
||||
if fiberCtx != nil {
|
||||
<-fiberCtx.Done()
|
||||
cancel()
|
||||
}
|
||||
}(ctx.Context())
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.ResponseFormat != nil {
|
||||
switch responseFormat := input.ResponseFormat.(type) {
|
||||
case string:
|
||||
config.ResponseFormat = responseFormat
|
||||
case map[string]interface{}:
|
||||
config.ResponseFormatMap = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||
for i, m := range input.Messages {
|
||||
nrOfImgsInMessage := 0
|
||||
nrOfVideosInMessage := 0
|
||||
nrOfAudiosInMessage := 0
|
||||
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
|
||||
textContent := ""
|
||||
// we will template this at the end
|
||||
|
||||
CONTENT:
|
||||
for _, pp := range c {
|
||||
switch pp.Type {
|
||||
case "text":
|
||||
textContent += pp.Text
|
||||
//input.Messages[i].StringContent = pp.Text
|
||||
case "video", "video_url":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding video: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||
vidIndex++
|
||||
nrOfVideosInMessage++
|
||||
case "audio_url", "audio":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding audio: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "input_audio":
|
||||
// TODO: make sure that we only return base64 stuff
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "image_url", "image":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
|
||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
TotalVideos: vidIndex,
|
||||
TotalAudios: audioIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
VideosInMessage: nrOfVideosInMessage,
|
||||
AudiosInMessage: nrOfAudiosInMessage,
|
||||
}, textContent)
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []any:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []any:
|
||||
tokens := []int{}
|
||||
inputStrings := []string{}
|
||||
for _, ii := range i {
|
||||
switch ii := ii.(type) {
|
||||
case int:
|
||||
tokens = append(tokens, ii)
|
||||
case float64:
|
||||
tokens = append(tokens, int(ii))
|
||||
case string:
|
||||
inputStrings = append(inputStrings, ii)
|
||||
default:
|
||||
log.Error().Msgf("Unknown input type: %T", ii)
|
||||
}
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
config.InputStrings = append(config.InputStrings, inputStrings...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a quality was defined as number, convert it to step
|
||||
if input.Quality != "" {
|
||||
q, err := strconv.Atoi(input.Quality)
|
||||
if err == nil {
|
||||
config.Step = q
|
||||
}
|
||||
}
|
||||
|
||||
if config.Validate() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type correlationIDKeyType string
|
||||
|
||||
// CorrelationIDKey to track request across process boundary
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
modelConfigLoader: modelConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
|
||||
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
|
||||
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
|
||||
|
||||
// TODO: Refactor to not return error if unchanged
|
||||
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
|
||||
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && model != "" {
|
||||
return
|
||||
}
|
||||
model = c.Param("model")
|
||||
|
||||
if model == "" {
|
||||
model = c.QueryParam("model")
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
// Set model from bearer token, if available
|
||||
auth := c.Request().Header.Get("Authorization")
|
||||
bearer := strings.TrimPrefix(auth, "Bearer ")
|
||||
if bearer != "" && bearer != auth {
|
||||
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
re.setModelNameFromRequest(c)
|
||||
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || localModelName == "" {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
||||
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
re.setModelNameFromRequest(c)
|
||||
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if localModelName != "" { // Don't overwrite existing values
|
||||
return next(c)
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return next(c)
|
||||
}
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
||||
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
||||
// return errors.New("this endpoint requires at least one model to be installed")
|
||||
return next(c)
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
||||
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
||||
// Otherwise, it's in its own method below for now
|
||||
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := initializer()
|
||||
if input == nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
|
||||
}
|
||||
if err := c.Bind(input); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
|
||||
}
|
||||
|
||||
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
||||
if input.ModelName(nil) == "" {
|
||||
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && localModelName != "" {
|
||||
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
|
||||
input.ModelName(&localModelName)
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
|
||||
} else if cfg.Model == "" && input.ModelName(nil) != "" {
|
||||
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
|
||||
cfg.Model = input.ModelName(nil)
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
|
||||
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID
|
||||
correlationID := c.Request().Header.Get("X-Correlation-ID")
|
||||
if correlationID == "" {
|
||||
correlationID = uuid.New().String()
|
||||
}
|
||||
c.Response().Header().Set("X-Correlation-ID", correlationID)
|
||||
|
||||
// Use the request context directly - Echo properly supports context cancellation!
|
||||
// No need for workarounds like handleConnectionCancellation
|
||||
reqCtx := c.Request().Context()
|
||||
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
|
||||
// Cancel when request context is cancelled (client disconnects)
|
||||
go func() {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
cancel()
|
||||
case <-c1.Done():
|
||||
// Already cancelled
|
||||
}
|
||||
}()
|
||||
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.ResponseFormat != nil {
|
||||
switch responseFormat := input.ResponseFormat.(type) {
|
||||
case string:
|
||||
config.ResponseFormat = responseFormat
|
||||
case map[string]interface{}:
|
||||
config.ResponseFormatMap = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||
for i, m := range input.Messages {
|
||||
nrOfImgsInMessage := 0
|
||||
nrOfVideosInMessage := 0
|
||||
nrOfAudiosInMessage := 0
|
||||
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
|
||||
textContent := ""
|
||||
// we will template this at the end
|
||||
|
||||
CONTENT:
|
||||
for _, pp := range c {
|
||||
switch pp.Type {
|
||||
case "text":
|
||||
textContent += pp.Text
|
||||
//input.Messages[i].StringContent = pp.Text
|
||||
case "video", "video_url":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding video: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||
vidIndex++
|
||||
nrOfVideosInMessage++
|
||||
case "audio_url", "audio":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding audio: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "input_audio":
|
||||
// TODO: make sure that we only return base64 stuff
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "image_url", "image":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
|
||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
TotalVideos: vidIndex,
|
||||
TotalAudios: audioIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
VideosInMessage: nrOfVideosInMessage,
|
||||
AudiosInMessage: nrOfAudiosInMessage,
|
||||
}, textContent)
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []any:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []any:
|
||||
tokens := []int{}
|
||||
inputStrings := []string{}
|
||||
for _, ii := range i {
|
||||
switch ii := ii.(type) {
|
||||
case int:
|
||||
tokens = append(tokens, ii)
|
||||
case float64:
|
||||
tokens = append(tokens, int(ii))
|
||||
case string:
|
||||
inputStrings = append(inputStrings, ii)
|
||||
default:
|
||||
log.Error().Msgf("Unknown input type: %T", ii)
|
||||
}
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
config.InputStrings = append(config.InputStrings, inputStrings...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a quality was defined as number, convert it to step
|
||||
if input.Quality != "" {
|
||||
q, err := strconv.Atoi(input.Quality)
|
||||
if err == nil {
|
||||
config.Step = q
|
||||
}
|
||||
}
|
||||
|
||||
if config.Validate() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
||||
|
||||
@@ -3,34 +3,55 @@ package middleware
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// StripPathPrefix returns a middleware that strips a path prefix from the request path.
|
||||
// StripPathPrefix returns middleware that strips a path prefix from the request path.
|
||||
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
|
||||
func StripPathPrefix() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
|
||||
if prefix != "" {
|
||||
path := c.Path()
|
||||
pos := len(prefix)
|
||||
// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
|
||||
func StripPathPrefix() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
|
||||
originalPath := c.Request().URL.Path
|
||||
|
||||
if prefix[pos-1] == '/' {
|
||||
pos--
|
||||
} else {
|
||||
prefix += "/"
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
if prefix != "" {
|
||||
normalizedPrefix := prefix
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
normalizedPrefix = prefix + "/"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
c.Path(path[pos:])
|
||||
break
|
||||
} else if prefix[:pos] == path {
|
||||
c.Redirect(prefix)
|
||||
return nil
|
||||
if strings.HasPrefix(originalPath, normalizedPrefix) {
|
||||
// Update the request path by stripping the normalized prefix
|
||||
newPath := originalPath[len(normalizedPrefix):]
|
||||
if newPath == "" {
|
||||
newPath = "/"
|
||||
}
|
||||
// Ensure path starts with / for proper routing
|
||||
if !strings.HasPrefix(newPath, "/") {
|
||||
newPath = "/" + newPath
|
||||
}
|
||||
// Update the URL path - Echo's router uses URL.Path for routing
|
||||
c.Request().URL.Path = newPath
|
||||
c.Request().URL.RawPath = ""
|
||||
// Update RequestURI to match the new path (needed for proper routing)
|
||||
if c.Request().URL.RawQuery != "" {
|
||||
c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
|
||||
} else {
|
||||
c.Request().RequestURI = newPath
|
||||
}
|
||||
// Store original path for BaseURL utility
|
||||
c.Set("_original_path", originalPath)
|
||||
break
|
||||
} else if originalPath == prefix || originalPath == prefix+"/" {
|
||||
// Redirect to prefix with trailing slash (use 302 to match test expectations)
|
||||
return c.Redirect(302, normalizedPrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,120 +2,133 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestStripPathPrefix(t *testing.T) {
|
||||
var _ = Describe("StripPathPrefix", func() {
|
||||
var app *echo.Echo
|
||||
var actualPath string
|
||||
var appInitialized bool
|
||||
|
||||
app := fiber.New()
|
||||
BeforeEach(func() {
|
||||
actualPath = ""
|
||||
if !appInitialized {
|
||||
app = echo.New()
|
||||
app.Pre(StripPathPrefix())
|
||||
|
||||
app.Use(StripPathPrefix())
|
||||
app.GET("/hello/world", func(c echo.Context) error {
|
||||
actualPath = c.Request().URL.Path
|
||||
return nil
|
||||
})
|
||||
|
||||
app.Get("/hello/world", func(c *fiber.Ctx) error {
|
||||
actualPath = c.Path()
|
||||
return nil
|
||||
app.GET("/", func(c echo.Context) error {
|
||||
actualPath = c.Request().URL.Path
|
||||
return nil
|
||||
})
|
||||
appInitialized = true
|
||||
}
|
||||
})
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
actualPath = c.Path()
|
||||
return nil
|
||||
})
|
||||
Context("without prefix", func() {
|
||||
It("should not modify path when no header is present", func() {
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
path string
|
||||
prefixHeader []string
|
||||
expectStatus int
|
||||
expectPath string
|
||||
}{
|
||||
{
|
||||
name: "without prefix and header",
|
||||
path: "/hello/world",
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "without prefix and headers on root path",
|
||||
path: "/",
|
||||
expectStatus: 200,
|
||||
expectPath: "/",
|
||||
},
|
||||
{
|
||||
name: "without prefix but header",
|
||||
path: "/hello/world",
|
||||
prefixHeader: []string{"/otherprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix but non-matching header",
|
||||
path: "/prefix/hello/world",
|
||||
prefixHeader: []string{"/otherprefix/"},
|
||||
expectStatus: 404,
|
||||
},
|
||||
{
|
||||
name: "with prefix and matching header",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/myprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and 1st header matching",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and 2nd header matching",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and header not ending with slash",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/myprefix"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and non-matching header not ending with slash",
|
||||
path: "/myprefix-suffix/hello/world",
|
||||
prefixHeader: []string{"/myprefix"},
|
||||
expectStatus: 404,
|
||||
},
|
||||
{
|
||||
name: "redirect when prefix does not end with a slash",
|
||||
path: "/myprefix",
|
||||
prefixHeader: []string{"/myprefix"},
|
||||
expectStatus: 302,
|
||||
expectPath: "/myprefix/",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actualPath = ""
|
||||
req := httptest.NewRequest("GET", tc.path, nil)
|
||||
if tc.prefixHeader != nil {
|
||||
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
|
||||
}
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
|
||||
|
||||
if tc.expectStatus == 200 {
|
||||
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
|
||||
} else if tc.expectStatus == 302 {
|
||||
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
|
||||
}
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
It("should not modify root path when no header is present", func() {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should not modify path when header does not match", func() {
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
})
|
||||
|
||||
Context("with prefix", func() {
|
||||
It("should return 404 when prefix does not match header", func() {
|
||||
req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(404), "response status code")
|
||||
})
|
||||
|
||||
It("should strip matching prefix from path", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should strip prefix when it matches the first header value", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should strip prefix when it matches the second header value", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should strip prefix when header does not end with slash", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should return 404 when prefix does not match header without trailing slash", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(404), "response status code")
|
||||
})
|
||||
|
||||
It("should redirect when prefix does not end with a slash", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(302), "response status code")
|
||||
Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"fmt"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp
|
||||
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
var tmpdir string
|
||||
var appServer *application.Application
|
||||
var app *fiber.App
|
||||
var app *echo.Echo
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
@@ -97,7 +97,9 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
_ = app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = app.Shutdown(ctx)
|
||||
}
|
||||
_ = os.RemoveAll(tmpdir)
|
||||
})
|
||||
@@ -106,7 +108,11 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
var err error
|
||||
app, err = API(appServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9091")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
|
||||
// Log error if needed
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
@@ -4,13 +4,15 @@ import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
fiberhtml "github.com/gofiber/template/html/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/russross/blackfriday"
|
||||
)
|
||||
@@ -18,26 +20,67 @@ import (
|
||||
//go:embed views/*
|
||||
var viewsfs embed.FS
|
||||
|
||||
func notFoundHandler(c *fiber.Ctx) error {
|
||||
// TemplateRenderer is a custom template renderer for Echo
|
||||
type TemplateRenderer struct {
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// Render renders a template document
|
||||
func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
|
||||
return t.templates.ExecuteTemplate(w, name, data)
|
||||
}
|
||||
|
||||
func notFoundHandler(c echo.Context) error {
|
||||
// Check if the request accepts JSON
|
||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
accept := c.Request().Header.Get("Accept")
|
||||
if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") {
|
||||
// The client expects a JSON response
|
||||
return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
|
||||
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound},
|
||||
})
|
||||
} else {
|
||||
// The client expects an HTML response
|
||||
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func renderEngine() *fiberhtml.Engine {
|
||||
engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
|
||||
engine.AddFuncMap(sprig.FuncMap())
|
||||
engine.AddFunc("MDToHTML", markDowner)
|
||||
return engine
|
||||
func renderEngine() *TemplateRenderer {
|
||||
// Parse all templates from embedded filesystem
|
||||
tmpl := template.New("").Funcs(sprig.FuncMap())
|
||||
tmpl = tmpl.Funcs(template.FuncMap{
|
||||
"MDToHTML": markDowner,
|
||||
})
|
||||
|
||||
// Recursively walk through embedded filesystem and parse all HTML templates
|
||||
err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".html") {
|
||||
data, err := viewsfs.ReadFile(path)
|
||||
if err == nil {
|
||||
// Remove .html extension to get template name (e.g., "views/index.html" -> "views/index")
|
||||
templateName := strings.TrimSuffix(path, ".html")
|
||||
_, err := tmpl.New(templateName).Parse(string(data))
|
||||
if err != nil {
|
||||
// If parsing fails, try parsing without explicit name (for templates with {{define}})
|
||||
tmpl.Parse(string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but continue - templates might still work
|
||||
fmt.Printf("Error walking views directory: %v\n", err)
|
||||
}
|
||||
|
||||
return &TemplateRenderer{
|
||||
templates: tmpl,
|
||||
}
|
||||
}
|
||||
|
||||
func markDowner(args ...interface{}) template.HTML {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -9,21 +9,23 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterElevenLabsRoutes(app *fiber.App,
|
||||
func RegisterElevenLabsRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
// Elevenlabs
|
||||
app.Post("/v1/text-to-speech/:voice-id",
|
||||
ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig)
|
||||
app.POST("/v1/text-to-speech/:voice-id",
|
||||
ttsHandler,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
|
||||
elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }))
|
||||
|
||||
app.Post("/v1/sound-generation",
|
||||
soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)
|
||||
app.POST("/v1/sound-generation",
|
||||
soundGenHandler,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
|
||||
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }))
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
coreExplorer "github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/explorer"
|
||||
)
|
||||
|
||||
func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database) {
|
||||
app.Get("/", explorer.Dashboard())
|
||||
app.Post("/network/add", explorer.AddNetwork(db))
|
||||
app.Get("/networks", explorer.ShowNetworks(db))
|
||||
func RegisterExplorerRoutes(app *echo.Echo, db *coreExplorer.Database) {
|
||||
app.GET("/", explorer.Dashboard())
|
||||
app.POST("/network/add", explorer.AddNetwork(db))
|
||||
app.GET("/networks", explorer.ShowNetworks(db))
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package routes
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func HealthRoutes(app *fiber.App) {
|
||||
func HealthRoutes(app *echo.Echo) {
|
||||
// Service health checks
|
||||
ok := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(200)
|
||||
ok := func(c echo.Context) error {
|
||||
return c.NoContent(200)
|
||||
}
|
||||
|
||||
app.Get("/healthz", ok)
|
||||
app.Get("/readyz", ok)
|
||||
app.GET("/healthz", ok)
|
||||
app.GET("/readyz", ok)
|
||||
}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/jina"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterJINARoutes(app *fiber.App,
|
||||
func RegisterJINARoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
// POST endpoint to mimic the reranking
|
||||
app.Post("/v1/rerank",
|
||||
rerankHandler := jina.JINARerankEndpoint(cl, ml, appConfig)
|
||||
app.POST("/v1/rerank",
|
||||
rerankHandler,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
|
||||
jina.JINARerankEndpoint(cl, ml, appConfig))
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }))
|
||||
}
|
||||
|
||||
@@ -1,131 +1,157 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/swagger"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
echoswagger "github.com/swaggo/echo-swagger"
|
||||
)
|
||||
|
||||
func RegisterLocalAIRoutes(router *fiber.App,
|
||||
func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
requestExtractor *middleware.RequestExtractor,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService,
|
||||
opcache *services.OpCache) {
|
||||
opcache *services.OpCache,
|
||||
evaluator *templates.Evaluator) {
|
||||
|
||||
router.Get("/swagger/*", swagger.HandlerDefault) // default
|
||||
router.GET("/swagger/*", echoswagger.WrapHandler) // default
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
// Import model page
|
||||
router.Get("/import-model", func(c *fiber.Ctx) error {
|
||||
return c.Render("views/model-editor", fiber.Map{
|
||||
router.GET("/import-model", func(c echo.Context) error {
|
||||
return c.Render(200, "views/model-editor", map[string]interface{}{
|
||||
"Title": "LocalAI - Import Model",
|
||||
"BaseURL": httpUtils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
})
|
||||
})
|
||||
|
||||
// Edit model page
|
||||
router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
|
||||
router.GET("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
|
||||
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
router.POST("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.POST("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
|
||||
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
router.GET("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
|
||||
router.GET("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.GET("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.GET("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
|
||||
backendGalleryEndpointService := localai.CreateBackendEndpointService(
|
||||
appConfig.BackendGalleries,
|
||||
appConfig.SystemState,
|
||||
galleryService)
|
||||
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
|
||||
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
|
||||
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.POST("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
|
||||
router.POST("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
|
||||
router.GET("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
|
||||
router.GET("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.GET("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.GET("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
// Custom model import endpoint
|
||||
router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig))
|
||||
router.POST("/models/import", localai.ImportModelEndpoint(cl, appConfig))
|
||||
|
||||
// URI model import endpoint
|
||||
router.Post("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
|
||||
router.POST("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
|
||||
|
||||
// Custom model edit endpoint
|
||||
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
|
||||
router.POST("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
|
||||
|
||||
// Reload models endpoint
|
||||
router.Post("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
|
||||
router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
|
||||
}
|
||||
|
||||
router.Post("/v1/detection",
|
||||
detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig)
|
||||
router.POST("/v1/detection",
|
||||
detectionHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
|
||||
localai.DetectionEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }))
|
||||
|
||||
router.Post("/tts",
|
||||
ttsHandler := localai.TTSEndpoint(cl, ml, appConfig)
|
||||
router.POST("/tts",
|
||||
ttsHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
localai.TTSEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }))
|
||||
|
||||
vadChain := []fiber.Handler{
|
||||
vadHandler := localai.VADEndpoint(cl, ml, appConfig)
|
||||
router.POST("/vad",
|
||||
vadHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
|
||||
localai.VADEndpoint(cl, ml, appConfig),
|
||||
}
|
||||
router.Post("/vad", vadChain...)
|
||||
router.Post("/v1/vad", vadChain...)
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
|
||||
router.POST("/v1/vad",
|
||||
vadHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
|
||||
|
||||
// Stores
|
||||
router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
|
||||
router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
|
||||
router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
|
||||
router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
|
||||
router.POST("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
|
||||
router.POST("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
|
||||
router.POST("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
|
||||
router.POST("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
|
||||
|
||||
if !appConfig.DisableMetrics {
|
||||
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
router.GET("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
}
|
||||
|
||||
router.Post("/video",
|
||||
videoHandler := localai.VideoEndpoint(cl, ml, appConfig)
|
||||
router.POST("/video",
|
||||
videoHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }),
|
||||
localai.VideoEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }))
|
||||
|
||||
// Backend Statistics Module
|
||||
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
|
||||
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
|
||||
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
|
||||
// p2p
|
||||
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||
router.GET("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||
router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||
|
||||
router.Get("/version", func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
router.GET("/version", func(c echo.Context) error {
|
||||
return c.JSON(200, struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
router.Get("/system", localai.SystemInformations(ml, appConfig))
|
||||
router.GET("/system", localai.SystemInformations(ml, appConfig))
|
||||
|
||||
// misc
|
||||
router.Post("/v1/tokenize",
|
||||
tokenizeHandler := localai.TokenizeEndpoint(cl, ml, appConfig)
|
||||
router.POST("/v1/tokenize",
|
||||
tokenizeHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
|
||||
localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }))
|
||||
|
||||
// MCP Stream endpoint
|
||||
if evaluator != nil {
|
||||
mcpStreamHandler := localai.MCPStreamEndpoint(cl, ml, evaluator, appConfig)
|
||||
mcpStreamMiddleware := []echo.MiddlewareFunc{
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := requestExtractor.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
router.POST("/v1/mcp/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
|
||||
router.POST("/mcp/v1/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
@@ -10,118 +10,172 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
func RegisterOpenAIRoutes(app *fiber.App,
|
||||
func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// realtime
|
||||
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
|
||||
app.Get("/v1/realtime", openai.Realtime(application))
|
||||
app.Post("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
|
||||
app.Post("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
|
||||
app.GET("/v1/realtime", openai.Realtime(application))
|
||||
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
|
||||
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
|
||||
|
||||
// chat
|
||||
chatChain := []fiber.Handler{
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/chat/completions", chatChain...)
|
||||
app.Post("/chat/completions", chatChain...)
|
||||
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
|
||||
app.POST("/chat/completions", chatHandler, chatMiddleware...)
|
||||
|
||||
// edit
|
||||
editChain := []fiber.Handler{
|
||||
editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
editMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/edits", editChain...)
|
||||
app.Post("/edits", editChain...)
|
||||
app.POST("/v1/edits", editHandler, editMiddleware...)
|
||||
app.POST("/edits", editHandler, editMiddleware...)
|
||||
|
||||
// completion
|
||||
completionChain := []fiber.Handler{
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
completionMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/completions", completionChain...)
|
||||
app.Post("/completions", completionChain...)
|
||||
app.Post("/v1/engines/:model/completions", completionChain...)
|
||||
app.POST("/v1/completions", completionHandler, completionMiddleware...)
|
||||
app.POST("/completions", completionHandler, completionMiddleware...)
|
||||
app.POST("/v1/engines/:model/completions", completionHandler, completionMiddleware...)
|
||||
|
||||
// MCPcompletion
|
||||
mcpCompletionChain := []fiber.Handler{
|
||||
mcpCompletionHandler := openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
mcpCompletionMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/mcp/v1/chat/completions", mcpCompletionChain...)
|
||||
app.Post("/mcp/chat/completions", mcpCompletionChain...)
|
||||
app.POST("/mcp/v1/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
|
||||
app.POST("/mcp/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
|
||||
|
||||
// embeddings
|
||||
embeddingChain := []fiber.Handler{
|
||||
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
embeddingMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/embeddings", embeddingChain...)
|
||||
app.Post("/embeddings", embeddingChain...)
|
||||
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
|
||||
app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
app.POST("/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
app.POST("/v1/engines/:model/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
|
||||
audioChain := []fiber.Handler{
|
||||
audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
audioMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", audioChain...)
|
||||
app.Post("/audio/transcriptions", audioChain...)
|
||||
app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...)
|
||||
app.POST("/audio/transcriptions", audioHandler, audioMiddleware...)
|
||||
|
||||
audioSpeechChain := []fiber.Handler{
|
||||
audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
audioSpeechMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
}
|
||||
|
||||
app.Post("/v1/audio/speech",
|
||||
audioSpeechChain...)
|
||||
app.Post("/audio/speech", audioSpeechChain...)
|
||||
app.POST("/v1/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
|
||||
app.POST("/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
|
||||
|
||||
// images
|
||||
imageChain := []fiber.Handler{
|
||||
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
imageMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
app.Post("/v1/images/generations",
|
||||
imageChain...)
|
||||
app.Post("/images/generations", imageChain...)
|
||||
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
|
||||
app.POST("/images/generations", imageHandler, imageMiddleware...)
|
||||
|
||||
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
|
||||
videoChain := []fiber.Handler{
|
||||
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
videoMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// OpenAI-style create video endpoint
|
||||
app.Post("/v1/videos", videoChain...)
|
||||
app.Post("/v1/videos/generations", videoChain...)
|
||||
app.Post("/videos", videoChain...)
|
||||
app.POST("/v1/videos", videoHandler, videoMiddleware...)
|
||||
app.POST("/v1/videos/generations", videoHandler, videoMiddleware...)
|
||||
app.POST("/videos", videoHandler, videoMiddleware...)
|
||||
|
||||
// List models
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
}
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func RegisterUIRoutes(app *fiber.App,
|
||||
func RegisterUIRoutes(app *echo.Echo,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
@@ -21,13 +20,14 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
// keeps the state of ops that are started from the UI
|
||||
var processingOps = services.NewOpCache(galleryService)
|
||||
|
||||
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
|
||||
app.GET("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
|
||||
app.GET("/manage", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
|
||||
|
||||
// P2P
|
||||
app.Get("/p2p", func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
app.GET("/p2p/", func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - P2P dashboard",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
//"Nodes": p2p.GetAvailableNodes(""),
|
||||
//"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID),
|
||||
@@ -37,7 +37,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/p2p", summary)
|
||||
return c.Render(200, "views/p2p", summary)
|
||||
})
|
||||
|
||||
// Note: P2P UI fragment routes (/p2p/ui/*) were removed
|
||||
@@ -50,17 +50,17 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
registerBackendGalleryRoutes(app, appConfig, galleryService, processingOps)
|
||||
}
|
||||
|
||||
app.Get("/talk/", func(c *fiber.Ctx) error {
|
||||
app.GET("/talk/", func(c echo.Context) error {
|
||||
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
|
||||
if len(modelConfigs) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Talk",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"Model": modelConfigs[0],
|
||||
|
||||
@@ -68,16 +68,16 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/talk", summary)
|
||||
return c.Render(200, "views/talk", summary)
|
||||
})
|
||||
|
||||
app.Get("/chat/", func(c *fiber.Ctx) error {
|
||||
app.GET("/chat/", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
modelThatCanBeUsed := ""
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
@@ -104,9 +104,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
@@ -116,16 +116,16 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/chat", summary)
|
||||
return c.Render(200, "views/chat", summary)
|
||||
})
|
||||
|
||||
// Show the Chat page
|
||||
app.Get("/chat/:model", func(c *fiber.Ctx) error {
|
||||
app.GET("/chat/:model", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
modelName := c.Params("model")
|
||||
modelName := c.Param("model")
|
||||
var modelContextSize *int
|
||||
|
||||
for _, m := range modelConfigs {
|
||||
@@ -139,9 +139,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Chat with " + modelName,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
@@ -151,33 +151,33 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/chat", summary)
|
||||
return c.Render(200, "views/chat", summary)
|
||||
})
|
||||
|
||||
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
|
||||
app.GET("/text2image/:model", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Generate images with " + c.Param("model"),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Model": c.Param("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/text2image", summary)
|
||||
return c.Render(200, "views/text2image", summary)
|
||||
})
|
||||
|
||||
app.Get("/text2image/", func(c *fiber.Ctx) error {
|
||||
app.GET("/text2image/", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
|
||||
modelThatCanBeUsed := ""
|
||||
@@ -191,9 +191,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
@@ -201,33 +201,33 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/text2image", summary)
|
||||
return c.Render(200, "views/text2image", summary)
|
||||
})
|
||||
|
||||
app.Get("/tts/:model", func(c *fiber.Ctx) error {
|
||||
app.GET("/tts/:model", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Generate images with " + c.Param("model"),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Model": c.Param("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/tts", summary)
|
||||
return c.Render(200, "views/tts", summary)
|
||||
})
|
||||
|
||||
app.Get("/tts/", func(c *fiber.Ctx) error {
|
||||
app.GET("/tts/", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
|
||||
modelThatCanBeUsed := ""
|
||||
@@ -240,9 +240,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
break
|
||||
}
|
||||
}
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
@@ -250,6 +250,6 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/tts", summary)
|
||||
return c.Render(200, "views/tts", summary)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,30 +1,33 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// RegisterUIAPIRoutes registers JSON API routes for the web UI
|
||||
func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
|
||||
// Operations API - Get all current operations (models + backends)
|
||||
app.Get("/api/operations", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/operations", func(c echo.Context) error {
|
||||
processingData, taskTypes := opcache.GetStatus()
|
||||
|
||||
operations := []fiber.Map{}
|
||||
operations := []map[string]interface{}{}
|
||||
for galleryID, jobID := range processingData {
|
||||
taskType := "installation"
|
||||
if tt, ok := taskTypes[galleryID]; ok {
|
||||
@@ -35,23 +38,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
progress := 0
|
||||
isDeletion := false
|
||||
isQueued := false
|
||||
isCancelled := false
|
||||
isCancellable := false
|
||||
message := ""
|
||||
|
||||
if status != nil {
|
||||
// Skip completed operations
|
||||
if status.Processed {
|
||||
// Skip completed operations (unless cancelled and not yet cleaned up)
|
||||
if status.Processed && !status.Cancelled {
|
||||
continue
|
||||
}
|
||||
// Skip cancelled operations that are processed (they're done, no need to show)
|
||||
if status.Processed && status.Cancelled {
|
||||
continue
|
||||
}
|
||||
|
||||
progress = int(status.Progress)
|
||||
isDeletion = status.Deletion
|
||||
isCancelled = status.Cancelled
|
||||
isCancellable = status.Cancellable
|
||||
message = status.Message
|
||||
if isDeletion {
|
||||
taskType = "deletion"
|
||||
}
|
||||
if isCancelled {
|
||||
taskType = "cancelled"
|
||||
}
|
||||
} else {
|
||||
// Job is queued but hasn't started
|
||||
isQueued = true
|
||||
isCancellable = true
|
||||
message = "Operation queued"
|
||||
}
|
||||
|
||||
@@ -75,17 +90,19 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
}
|
||||
|
||||
operations = append(operations, fiber.Map{
|
||||
"id": galleryID,
|
||||
"name": displayName,
|
||||
"fullName": galleryID,
|
||||
"jobID": jobID,
|
||||
"progress": progress,
|
||||
"taskType": taskType,
|
||||
"isDeletion": isDeletion,
|
||||
"isBackend": isBackend,
|
||||
"isQueued": isQueued,
|
||||
"message": message,
|
||||
operations = append(operations, map[string]interface{}{
|
||||
"id": galleryID,
|
||||
"name": displayName,
|
||||
"fullName": galleryID,
|
||||
"jobID": jobID,
|
||||
"progress": progress,
|
||||
"taskType": taskType,
|
||||
"isDeletion": isDeletion,
|
||||
"isBackend": isBackend,
|
||||
"isQueued": isQueued,
|
||||
"isCancelled": isCancelled,
|
||||
"cancellable": isCancellable,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -103,21 +120,49 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
return operations[i]["id"].(string) < operations[j]["id"].(string)
|
||||
})
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"operations": operations,
|
||||
})
|
||||
})
|
||||
|
||||
// Cancel operation endpoint
|
||||
app.POST("/api/operations/:jobID/cancel", func(c echo.Context) error {
|
||||
jobID := c.Param("jobID")
|
||||
log.Debug().Msgf("API request to cancel operation: %s", jobID)
|
||||
|
||||
err := galleryService.CancelOperation(jobID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Clean up opcache for cancelled operation
|
||||
opcache.DeleteUUID(jobID)
|
||||
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Operation cancelled",
|
||||
})
|
||||
})
|
||||
|
||||
// Model Gallery APIs
|
||||
app.Get("/api/models", func(c *fiber.Ctx) error {
|
||||
term := c.Query("term")
|
||||
page := c.Query("page", "1")
|
||||
items := c.Query("items", "21")
|
||||
app.GET("/api/models", func(c echo.Context) error {
|
||||
term := c.QueryParam("term")
|
||||
page := c.QueryParam("page")
|
||||
if page == "" {
|
||||
page = "1"
|
||||
}
|
||||
items := c.QueryParam("items")
|
||||
if items == "" {
|
||||
items = "21"
|
||||
}
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -160,7 +205,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
|
||||
// Convert models to JSON-friendly format and deduplicate by ID
|
||||
modelsJSON := make([]fiber.Map, 0, len(models))
|
||||
modelsJSON := make([]map[string]interface{}, 0, len(models))
|
||||
seenIDs := make(map[string]bool)
|
||||
|
||||
for _, m := range models {
|
||||
@@ -186,7 +231,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
_, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
|
||||
|
||||
modelsJSON = append(modelsJSON, fiber.Map{
|
||||
modelsJSON = append(modelsJSON, map[string]interface{}{
|
||||
"id": modelID,
|
||||
"name": m.Name,
|
||||
"description": m.Description,
|
||||
@@ -213,26 +258,32 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
nextPage = totalPages
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"models": modelsJSON,
|
||||
"repositories": appConfig.Galleries,
|
||||
"allTags": tags,
|
||||
"processingModels": processingModelsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableModels": totalModels,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
"nextPage": nextPage,
|
||||
// Calculate installed models count (models with configs + models without configs)
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
installedModelsCount := len(modelConfigs) + len(modelsWithoutConfig)
|
||||
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"models": modelsJSON,
|
||||
"repositories": appConfig.Galleries,
|
||||
"allTags": tags,
|
||||
"processingModels": processingModelsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableModels": totalModels,
|
||||
"installedModels": installedModelsCount,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
"nextPage": nextPage,
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/models/install/:id", func(c *fiber.Ctx) error {
|
||||
galleryID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/models/install/:id", func(c echo.Context) error {
|
||||
galleryID := c.Param("id")
|
||||
// URL decode the gallery ID (e.g., "localai%40model" -> "localai@model")
|
||||
galleryID, err := url.QueryUnescape(galleryID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid model ID",
|
||||
})
|
||||
}
|
||||
@@ -240,7 +291,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -248,28 +299,33 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
uid := id.String()
|
||||
opcache.Set(galleryID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
ID: uid,
|
||||
GalleryElementName: galleryID,
|
||||
Galleries: appConfig.Galleries,
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Installation started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/models/delete/:id", func(c *fiber.Ctx) error {
|
||||
galleryID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/models/delete/:id", func(c echo.Context) error {
|
||||
galleryID := c.Param("id")
|
||||
// URL decode the gallery ID
|
||||
galleryID, err := url.QueryUnescape(galleryID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid model ID",
|
||||
})
|
||||
}
|
||||
@@ -282,7 +338,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -291,30 +347,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
opcache.Set(galleryID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
ID: uid,
|
||||
Delete: true,
|
||||
GalleryElementName: galleryName,
|
||||
Galleries: appConfig.Galleries,
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
cl.RemoveModelConfig(galleryName)
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Deletion started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/models/config/:id", func(c *fiber.Ctx) error {
|
||||
galleryID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/models/config/:id", func(c echo.Context) error {
|
||||
galleryID := c.Param("id")
|
||||
// URL decode the gallery ID
|
||||
galleryID, err := url.QueryUnescape(galleryID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid model ID",
|
||||
})
|
||||
}
|
||||
@@ -322,44 +383,44 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
model := gallery.FindGalleryElement(models, galleryID)
|
||||
if model == nil {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{
|
||||
"error": "model not found",
|
||||
})
|
||||
}
|
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](model.URL, appConfig.SystemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
_, err = gallery.InstallModel(appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
|
||||
_, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"message": "Configuration file saved",
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/models/job/:uid", func(c *fiber.Ctx) error {
|
||||
jobUID := strings.Clone(c.Params("uid"))
|
||||
app.GET("/api/models/job/:uid", func(c echo.Context) error {
|
||||
jobUID := c.Param("uid")
|
||||
|
||||
status := galleryService.GetStatus(jobUID)
|
||||
if status == nil {
|
||||
// Job is queued but hasn't started processing yet
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"progress": 0,
|
||||
"message": "Operation queued",
|
||||
"galleryElementName": "",
|
||||
@@ -369,7 +430,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
response := fiber.Map{
|
||||
response := map[string]interface{}{
|
||||
"progress": status.Progress,
|
||||
"message": status.Message,
|
||||
"galleryElementName": status.GalleryElementName,
|
||||
@@ -387,19 +448,25 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
response["completed"] = true
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
})
|
||||
|
||||
// Backend Gallery APIs
|
||||
app.Get("/api/backends", func(c *fiber.Ctx) error {
|
||||
term := c.Query("term")
|
||||
page := c.Query("page", "1")
|
||||
items := c.Query("items", "21")
|
||||
app.GET("/api/backends", func(c echo.Context) error {
|
||||
term := c.QueryParam("term")
|
||||
page := c.QueryParam("page")
|
||||
if page == "" {
|
||||
page = "1"
|
||||
}
|
||||
items := c.QueryParam("items")
|
||||
if items == "" {
|
||||
items = "21"
|
||||
}
|
||||
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list backends from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -442,7 +509,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
|
||||
// Convert backends to JSON-friendly format and deduplicate by ID
|
||||
backendsJSON := make([]fiber.Map, 0, len(backends))
|
||||
backendsJSON := make([]map[string]interface{}, 0, len(backends))
|
||||
seenBackendIDs := make(map[string]bool)
|
||||
|
||||
for _, b := range backends {
|
||||
@@ -466,7 +533,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
}
|
||||
|
||||
backendsJSON = append(backendsJSON, fiber.Map{
|
||||
backendsJSON = append(backendsJSON, map[string]interface{}{
|
||||
"id": backendID,
|
||||
"name": b.Name,
|
||||
"description": b.Description,
|
||||
@@ -491,13 +558,21 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
nextPage = totalPages
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
// Calculate installed backends count
|
||||
installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
|
||||
installedBackendsCount := 0
|
||||
if err == nil {
|
||||
installedBackendsCount = len(installedBackends)
|
||||
}
|
||||
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"backends": backendsJSON,
|
||||
"repositories": appConfig.BackendGalleries,
|
||||
"allTags": tags,
|
||||
"processingBackends": processingBackendsData,
|
||||
"taskTypes": taskTypes,
|
||||
"availableBackends": totalBackends,
|
||||
"installedBackends": installedBackendsCount,
|
||||
"currentPage": pageNum,
|
||||
"totalPages": totalPages,
|
||||
"prevPage": prevPage,
|
||||
@@ -505,12 +580,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/backends/install/:id", func(c *fiber.Ctx) error {
|
||||
backendID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/backends/install/:id", func(c echo.Context) error {
|
||||
backendID := c.Param("id")
|
||||
// URL decode the backend ID
|
||||
backendID, err := url.QueryUnescape(backendID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid backend ID",
|
||||
})
|
||||
}
|
||||
@@ -518,7 +593,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -526,27 +601,32 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
uid := id.String()
|
||||
opcache.Set(backendID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryBackend, any]{
|
||||
ID: uid,
|
||||
GalleryElementName: backendID,
|
||||
Galleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Backend installation started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/backends/delete/:id", func(c *fiber.Ctx) error {
|
||||
backendID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/backends/delete/:id", func(c echo.Context) error {
|
||||
backendID := c.Param("id")
|
||||
// URL decode the backend ID
|
||||
backendID, err := url.QueryUnescape(backendID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid backend ID",
|
||||
})
|
||||
}
|
||||
@@ -559,7 +639,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -568,29 +648,34 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
opcache.Set(backendID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryBackend, any]{
|
||||
ID: uid,
|
||||
Delete: true,
|
||||
GalleryElementName: backendName,
|
||||
Galleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Backend deletion started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/backends/job/:uid", func(c *fiber.Ctx) error {
|
||||
jobUID := strings.Clone(c.Params("uid"))
|
||||
app.GET("/api/backends/job/:uid", func(c echo.Context) error {
|
||||
jobUID := c.Param("uid")
|
||||
|
||||
status := galleryService.GetStatus(jobUID)
|
||||
if status == nil {
|
||||
// Job is queued but hasn't started processing yet
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"progress": 0,
|
||||
"message": "Operation queued",
|
||||
"galleryElementName": "",
|
||||
@@ -600,7 +685,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
response := fiber.Map{
|
||||
response := map[string]interface{}{
|
||||
"progress": status.Progress,
|
||||
"message": status.Message,
|
||||
"galleryElementName": status.GalleryElementName,
|
||||
@@ -618,16 +703,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
response["completed"] = true
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
})
|
||||
|
||||
// System Backend Deletion API (for installed backends on index page)
|
||||
app.Post("/api/backends/system/delete/:name", func(c *fiber.Ctx) error {
|
||||
backendName := strings.Clone(c.Params("name"))
|
||||
app.POST("/api/backends/system/delete/:name", func(c echo.Context) error {
|
||||
backendName := c.Param("name")
|
||||
// URL decode the backend name
|
||||
backendName, err := url.QueryUnescape(backendName)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid backend name",
|
||||
})
|
||||
}
|
||||
@@ -636,24 +721,24 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
// Use the gallery package to delete the backend
|
||||
if err := gallery.DeleteBackendFromSystem(appConfig.SystemState, backendName); err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to delete backend: %s", backendName)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Backend deleted successfully",
|
||||
})
|
||||
})
|
||||
|
||||
// P2P APIs
|
||||
app.Get("/api/p2p/workers", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/p2p/workers", func(c echo.Context) error {
|
||||
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
|
||||
|
||||
nodesJSON := make([]fiber.Map, 0, len(nodes))
|
||||
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
nodesJSON = append(nodesJSON, fiber.Map{
|
||||
nodesJSON = append(nodesJSON, map[string]interface{}{
|
||||
"name": n.Name,
|
||||
"id": n.ID,
|
||||
"tunnelAddress": n.TunnelAddress,
|
||||
@@ -663,17 +748,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"nodes": nodesJSON,
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/p2p/federation", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/p2p/federation", func(c echo.Context) error {
|
||||
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
|
||||
|
||||
nodesJSON := make([]fiber.Map, 0, len(nodes))
|
||||
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
nodesJSON = append(nodesJSON, fiber.Map{
|
||||
nodesJSON = append(nodesJSON, map[string]interface{}{
|
||||
"name": n.Name,
|
||||
"id": n.ID,
|
||||
"tunnelAddress": n.TunnelAddress,
|
||||
@@ -683,12 +768,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"nodes": nodesJSON,
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/p2p/stats", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/p2p/stats", func(c echo.Context) error {
|
||||
workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
|
||||
federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
|
||||
|
||||
@@ -706,12 +791,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"workers": fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"workers": map[string]interface{}{
|
||||
"online": workersOnline,
|
||||
"total": len(workerNodes),
|
||||
},
|
||||
"federated": fiber.Map{
|
||||
"federated": map[string]interface{}{
|
||||
"online": federatedOnline,
|
||||
"total": len(federatedNodes),
|
||||
},
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
)
|
||||
|
||||
func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
// Show the Backends page (all backends are loaded client-side via Alpine.js)
|
||||
app.Get("/browse/backends", func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
app.GET("/browse/backends", func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Backends",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"Repositories": appConfig.BackendGalleries,
|
||||
}
|
||||
|
||||
// Render index - backends are now loaded via Alpine.js from /api/backends
|
||||
return c.Render("views/backends", summary)
|
||||
return c.Render(200, "views/backends", summary)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
)
|
||||
|
||||
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
|
||||
app.Get("/browse", func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
app.GET("/browse/", func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Models",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"Repositories": appConfig.Galleries,
|
||||
}
|
||||
|
||||
// Render index - models are now loaded via Alpine.js from /api/models
|
||||
return c.Render("views/models", summary)
|
||||
return c.Render(200, "views/models", summary)
|
||||
})
|
||||
}
|
||||
|
||||
1
core/http/static/assets/autorefresh.min.js
vendored
Normal file
1
core/http/static/assets/autorefresh.min.js
vendored
Normal file
@@ -0,0 +1 @@
|
||||
!function(e){"object"==typeof exports&&"object"==typeof module?e(require("../../lib/codemirror")):"function"==typeof define&&define.amd?define(["../../lib/codemirror"],e):e(CodeMirror)}(function(u){"use strict";function f(e,t){clearTimeout(t.timeout),u.off(window,"mouseup",t.hurry),u.off(window,"keyup",t.hurry)}u.defineOption("autoRefresh",!1,function(e,t){function o(){i.display.wrapper.offsetHeight?(f(0,r),i.display.lastWrapHeight!=i.display.wrapper.clientHeight&&i.refresh()):r.timeout=setTimeout(o,r.delay)}var i,r;e.state.autoRefresh&&(f(0,e.state.autoRefresh),e.state.autoRefresh=null),t&&0==e.display.wrapper.offsetHeight&&((r=(i=e).state.autoRefresh={delay:t.delay||250}).timeout=setTimeout(o,r.delay),r.hurry=function(){clearTimeout(r.timeout),r.timeout=setTimeout(o,50)},u.on(window,"mouseup",r.hurry),u.on(window,"keyup",r.hurry))})});
|
||||
1
core/http/static/assets/codemirror.min.css
vendored
Normal file
1
core/http/static/assets/codemirror.min.css
vendored
Normal file
File diff suppressed because one or more lines are too long
1
core/http/static/assets/codemirror.min.js
vendored
Normal file
1
core/http/static/assets/codemirror.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
core/http/static/assets/js-yaml.min.js
vendored
Normal file
2
core/http/static/assets/js-yaml.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
22
core/http/static/assets/pdf.min.js
vendored
Normal file
22
core/http/static/assets/pdf.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
22
core/http/static/assets/pdf.worker.min.js
vendored
Normal file
22
core/http/static/assets/pdf.worker.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
core/http/static/assets/yaml.min.js
vendored
Normal file
1
core/http/static/assets/yaml.min.js
vendored
Normal file
@@ -0,0 +1 @@
|
||||
!function(e){"object"==typeof exports&&"object"==typeof module?e(require("../../lib/codemirror")):"function"==typeof define&&define.amd?define(["../../lib/codemirror"],e):e(CodeMirror)}(function(e){"use strict";e.defineMode("yaml",function(){var n=new RegExp("\\b(("+["true","false","on","off","yes","no"].join(")|(")+"))$","i");return{token:function(e,i){var t=e.peek(),r=i.escaped;if(i.escaped=!1,"#"==t&&(0==e.pos||/\s/.test(e.string.charAt(e.pos-1))))return e.skipToEnd(),"comment";if(e.match(/^('([^']|\\.)*'?|"([^"]|\\.)*"?)/))return"string";if(i.literal&&e.indentation()>i.keyCol)return e.skipToEnd(),"string";if(i.literal&&(i.literal=!1),e.sol()){if(i.keyCol=0,i.pair=!1,i.pairStart=!1,e.match("---"))return"def";if(e.match("..."))return"def";if(e.match(/\s*-\s+/))return"meta"}if(e.match(/^(\{|\}|\[|\])/))return"{"==t?i.inlinePairs++:"}"==t?i.inlinePairs--:"["==t?i.inlineList++:i.inlineList--,"meta";if(0<i.inlineList&&!r&&","==t)return e.next(),"meta";if(0<i.inlinePairs&&!r&&","==t)return i.keyCol=0,i.pair=!1,i.pairStart=!1,e.next(),"meta";if(i.pairStart){if(e.match(/^\s*(\||\>)\s*/))return i.literal=!0,"meta";if(e.match(/^\s*(\&|\*)[a-z0-9\._-]+\b/i))return"variable-2";if(0==i.inlinePairs&&e.match(/^\s*-?[0-9\.\,]+\s?$/))return"number";if(0<i.inlinePairs&&e.match(/^\s*-?[0-9\.\,]+\s?(?=(,|}))/))return"number";if(e.match(n))return"keyword"}return!i.pair&&e.match(/^\s*(?:[,\[\]{}&*!|>'"%@`][^\s'":]|[^,\[\]{}#&*!|>'"%@`])[^#]*?(?=\s*:($|\s))/)?(i.pair=!0,i.keyCol=e.indentation(),"atom"):i.pair&&e.match(/^:\s*/)?(i.pairStart=!0,"meta"):(i.pairStart=!1,i.escaped="\\"==t,e.next(),null)},startState:function(){return{pair:!1,pairStart:!1,keyCol:0,inlinePairs:0,inlineList:0,literal:!1,escaped:!1}},lineComment:"#",fold:"indent"}}),e.defineMIME("text/x-yaml","yaml"),e.defineMIME("text/yaml","yaml")});
|
||||
@@ -90,6 +90,23 @@ function updateTokensPerSecond() {
|
||||
}
|
||||
}
|
||||
|
||||
function scrollThinkingBoxToBottom() {
|
||||
// Find all thinking/reasoning message containers that are expanded
|
||||
const thinkingBoxes = document.querySelectorAll('[data-thinking-box]');
|
||||
thinkingBoxes.forEach(box => {
|
||||
// Only scroll if the box is visible (expanded) and has overflow
|
||||
if (box.offsetParent !== null && box.scrollHeight > box.clientHeight) {
|
||||
box.scrollTo({
|
||||
top: box.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Make function available globally
|
||||
window.scrollThinkingBoxToBottom = scrollThinkingBoxToBottom;
|
||||
|
||||
function stopRequest() {
|
||||
if (currentAbortController) {
|
||||
currentAbortController.abort();
|
||||
@@ -160,6 +177,9 @@ var images = [];
|
||||
var audios = [];
|
||||
var fileContents = [];
|
||||
var currentFileNames = [];
|
||||
// Track file names to data URLs for proper removal
|
||||
var imageFileMap = new Map(); // fileName -> dataURL
|
||||
var audioFileMap = new Map(); // fileName -> dataURL
|
||||
|
||||
async function extractTextFromPDF(pdfData) {
|
||||
try {
|
||||
@@ -180,35 +200,119 @@ async function extractTextFromPDF(pdfData) {
|
||||
}
|
||||
}
|
||||
|
||||
// Global function to handle file selection and update Alpine.js state
|
||||
window.handleFileSelection = function(event, fileType) {
|
||||
if (!event.target.files || !event.target.files.length) return;
|
||||
|
||||
// Get the Alpine.js component - find the parent div with x-data containing attachedFiles
|
||||
let inputContainer = event.target.closest('[x-data*="attachedFiles"]');
|
||||
if (!inputContainer && window.Alpine) {
|
||||
// Fallback: find any element with attachedFiles in x-data
|
||||
inputContainer = document.querySelector('[x-data*="attachedFiles"]');
|
||||
}
|
||||
if (!inputContainer || !window.Alpine) return;
|
||||
|
||||
const alpineData = Alpine.$data(inputContainer);
|
||||
if (!alpineData || !alpineData.attachedFiles) return;
|
||||
|
||||
Array.from(event.target.files).forEach(file => {
|
||||
// Check if file already exists
|
||||
const exists = alpineData.attachedFiles.some(f => f.name === file.name && f.type === fileType);
|
||||
if (!exists) {
|
||||
alpineData.attachedFiles.push({ name: file.name, type: fileType });
|
||||
|
||||
// Process the file based on type
|
||||
if (fileType === 'image') {
|
||||
readInputImageFile(file);
|
||||
} else if (fileType === 'audio') {
|
||||
readInputAudioFile(file);
|
||||
} else if (fileType === 'file') {
|
||||
readInputFileFile(file);
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Global function to remove file from input
|
||||
window.removeFileFromInput = function(fileType, fileName) {
|
||||
// Remove from arrays
|
||||
if (fileType === 'image') {
|
||||
// Remove from images array using the mapping
|
||||
const dataURL = imageFileMap.get(fileName);
|
||||
if (dataURL) {
|
||||
const imageIndex = images.indexOf(dataURL);
|
||||
if (imageIndex !== -1) {
|
||||
images.splice(imageIndex, 1);
|
||||
}
|
||||
imageFileMap.delete(fileName);
|
||||
}
|
||||
} else if (fileType === 'audio') {
|
||||
// Remove from audios array using the mapping
|
||||
const dataURL = audioFileMap.get(fileName);
|
||||
if (dataURL) {
|
||||
const audioIndex = audios.indexOf(dataURL);
|
||||
if (audioIndex !== -1) {
|
||||
audios.splice(audioIndex, 1);
|
||||
}
|
||||
audioFileMap.delete(fileName);
|
||||
}
|
||||
} else if (fileType === 'file') {
|
||||
// Remove from fileContents and currentFileNames
|
||||
const fileIndex = currentFileNames.indexOf(fileName);
|
||||
if (fileIndex !== -1) {
|
||||
currentFileNames.splice(fileIndex, 1);
|
||||
fileContents.splice(fileIndex, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Also remove from the actual input element
|
||||
const inputId = fileType === 'image' ? 'input_image' :
|
||||
fileType === 'audio' ? 'input_audio' : 'input_file';
|
||||
const input = document.getElementById(inputId);
|
||||
if (input && input.files) {
|
||||
const dt = new DataTransfer();
|
||||
Array.from(input.files).forEach(file => {
|
||||
if (file.name !== fileName) {
|
||||
dt.items.add(file);
|
||||
}
|
||||
});
|
||||
input.files = dt.files;
|
||||
}
|
||||
};
|
||||
|
||||
function readInputFile() {
|
||||
if (!this.files || !this.files.length) return;
|
||||
|
||||
Array.from(this.files).forEach(file => {
|
||||
const FR = new FileReader();
|
||||
currentFileNames.push(file.name);
|
||||
const fileExtension = file.name.split('.').pop().toLowerCase();
|
||||
|
||||
FR.addEventListener("load", async function(evt) {
|
||||
if (fileExtension === 'pdf') {
|
||||
try {
|
||||
const content = await extractTextFromPDF(evt.target.result);
|
||||
fileContents.push({ name: file.name, content: content });
|
||||
} catch (error) {
|
||||
console.error('Error processing PDF:', error);
|
||||
fileContents.push({ name: file.name, content: "Error processing PDF file" });
|
||||
}
|
||||
} else {
|
||||
// For text and markdown files
|
||||
fileContents.push({ name: file.name, content: evt.target.result });
|
||||
}
|
||||
});
|
||||
readInputFileFile(file);
|
||||
});
|
||||
}
|
||||
|
||||
function readInputFileFile(file) {
|
||||
const FR = new FileReader();
|
||||
currentFileNames.push(file.name);
|
||||
const fileExtension = file.name.split('.').pop().toLowerCase();
|
||||
|
||||
FR.addEventListener("load", async function(evt) {
|
||||
if (fileExtension === 'pdf') {
|
||||
FR.readAsArrayBuffer(file);
|
||||
try {
|
||||
const content = await extractTextFromPDF(evt.target.result);
|
||||
fileContents.push({ name: file.name, content: content });
|
||||
} catch (error) {
|
||||
console.error('Error processing PDF:', error);
|
||||
fileContents.push({ name: file.name, content: "Error processing PDF file" });
|
||||
}
|
||||
} else {
|
||||
FR.readAsText(file);
|
||||
// For text and markdown files
|
||||
fileContents.push({ name: file.name, content: evt.target.result });
|
||||
}
|
||||
});
|
||||
|
||||
if (fileExtension === 'pdf') {
|
||||
FR.readAsArrayBuffer(file);
|
||||
} else {
|
||||
FR.readAsText(file);
|
||||
}
|
||||
}
|
||||
|
||||
function submitPrompt(event) {
|
||||
@@ -267,7 +371,15 @@ function processAndSendMessage(inputValue) {
|
||||
const input = document.getElementById("input");
|
||||
if (input) input.value = "";
|
||||
const systemPrompt = localStorage.getItem("system_prompt");
|
||||
Alpine.nextTick(() => { document.getElementById('messages').scrollIntoView(false); });
|
||||
Alpine.nextTick(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Reset token tracking before starting new request
|
||||
requestStartTime = Date.now();
|
||||
@@ -278,36 +390,66 @@ function processAndSendMessage(inputValue) {
|
||||
// Reset file contents and names after sending
|
||||
fileContents = [];
|
||||
currentFileNames = [];
|
||||
images = [];
|
||||
audios = [];
|
||||
imageFileMap.clear();
|
||||
audioFileMap.clear();
|
||||
|
||||
// Clear Alpine.js attachedFiles array
|
||||
const inputContainer = document.querySelector('[x-data*="attachedFiles"]');
|
||||
if (inputContainer && window.Alpine) {
|
||||
const alpineData = Alpine.$data(inputContainer);
|
||||
if (alpineData && alpineData.attachedFiles) {
|
||||
alpineData.attachedFiles = [];
|
||||
}
|
||||
}
|
||||
|
||||
// Clear file inputs
|
||||
document.getElementById("input_image").value = null;
|
||||
document.getElementById("input_audio").value = null;
|
||||
document.getElementById("input_file").value = null;
|
||||
}
|
||||
|
||||
function readInputImage() {
|
||||
if (!this.files || !this.files.length) return;
|
||||
|
||||
Array.from(this.files).forEach(file => {
|
||||
const FR = new FileReader();
|
||||
|
||||
FR.addEventListener("load", function(evt) {
|
||||
images.push(evt.target.result);
|
||||
});
|
||||
|
||||
FR.readAsDataURL(file);
|
||||
readInputImageFile(file);
|
||||
});
|
||||
}
|
||||
|
||||
function readInputImageFile(file) {
|
||||
const FR = new FileReader();
|
||||
|
||||
FR.addEventListener("load", function(evt) {
|
||||
const dataURL = evt.target.result;
|
||||
images.push(dataURL);
|
||||
imageFileMap.set(file.name, dataURL);
|
||||
});
|
||||
|
||||
FR.readAsDataURL(file);
|
||||
}
|
||||
|
||||
function readInputAudio() {
|
||||
if (!this.files || !this.files.length) return;
|
||||
|
||||
Array.from(this.files).forEach(file => {
|
||||
const FR = new FileReader();
|
||||
|
||||
FR.addEventListener("load", function(evt) {
|
||||
audios.push(evt.target.result);
|
||||
});
|
||||
|
||||
FR.readAsDataURL(file);
|
||||
readInputAudioFile(file);
|
||||
});
|
||||
}
|
||||
|
||||
function readInputAudioFile(file) {
|
||||
const FR = new FileReader();
|
||||
|
||||
FR.addEventListener("load", function(evt) {
|
||||
const dataURL = evt.target.result;
|
||||
audios.push(dataURL);
|
||||
audioFileMap.set(file.name, dataURL);
|
||||
});
|
||||
|
||||
FR.readAsDataURL(file);
|
||||
}
|
||||
|
||||
async function promptGPT(systemPrompt, input) {
|
||||
const model = document.getElementById("chat-model").value;
|
||||
const mcpMode = Alpine.store("chat").mcpMode;
|
||||
@@ -370,25 +512,18 @@ async function promptGPT(systemPrompt, input) {
|
||||
}
|
||||
});
|
||||
|
||||
// reset the form and the files
|
||||
images = [];
|
||||
audios = [];
|
||||
document.getElementById("input_image").value = null;
|
||||
document.getElementById("input_audio").value = null;
|
||||
document.getElementById("input_file").value = null;
|
||||
document.getElementById("fileName").innerHTML = "";
|
||||
// reset the form and the files (already done in processAndSendMessage)
|
||||
// images, audios, and file inputs are cleared after sending
|
||||
|
||||
// Choose endpoint based on MCP mode
|
||||
const endpoint = mcpMode ? "mcp/v1/chat/completions" : "v1/chat/completions";
|
||||
const endpoint = mcpMode ? "v1/mcp/chat/completions" : "v1/chat/completions";
|
||||
const requestBody = {
|
||||
model: model,
|
||||
messages: messages,
|
||||
};
|
||||
|
||||
// Only add stream parameter for regular chat (MCP doesn't support streaming)
|
||||
if (!mcpMode) {
|
||||
requestBody.stream = true;
|
||||
}
|
||||
// Add stream parameter for both regular chat and MCP (MCP now supports SSE streaming)
|
||||
requestBody.stream = true;
|
||||
|
||||
let response;
|
||||
try {
|
||||
@@ -444,64 +579,441 @@ async function promptGPT(systemPrompt, input) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle streaming response (both regular and MCP mode now use SSE)
|
||||
if (mcpMode) {
|
||||
// Handle MCP non-streaming response
|
||||
// Handle MCP SSE streaming with new event types
|
||||
const reader = response.body
|
||||
?.pipeThrough(new TextDecoderStream())
|
||||
.getReader();
|
||||
|
||||
if (!reader) {
|
||||
Alpine.store("chat").add(
|
||||
"assistant",
|
||||
`<span class='error'>Error: Failed to decode MCP API response</span>`,
|
||||
);
|
||||
toggleLoader(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store reader globally so stop button can cancel it
|
||||
currentReader = reader;
|
||||
|
||||
let buffer = "";
|
||||
let assistantContent = "";
|
||||
let assistantContentBuffer = [];
|
||||
let thinkingContent = "";
|
||||
let isThinking = false;
|
||||
let lastAssistantMessageIndex = -1;
|
||||
let lastThinkingMessageIndex = -1;
|
||||
let lastThinkingScrollTime = 0;
|
||||
const THINKING_SCROLL_THROTTLE = 200; // Throttle scrolling to every 200ms
|
||||
|
||||
try {
|
||||
const data = await response.json();
|
||||
|
||||
// Update token usage if present
|
||||
if (data.usage) {
|
||||
Alpine.store("chat").updateTokenUsage(data.usage);
|
||||
}
|
||||
|
||||
// MCP endpoint returns content in choices[0].message.content (chat completion format)
|
||||
// Fallback to choices[0].text for backward compatibility (completion format)
|
||||
const content = data.choices[0]?.message?.content || data.choices[0]?.text || "";
|
||||
|
||||
if (!content && (!data.choices || data.choices.length === 0)) {
|
||||
Alpine.store("chat").add(
|
||||
"assistant",
|
||||
`<span class='error'>Error: Empty response from MCP endpoint</span>`,
|
||||
);
|
||||
toggleLoader(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (content) {
|
||||
// Count tokens for rate calculation (MCP mode - full content at once)
|
||||
// Prefer actual token count from API if available
|
||||
if (data.usage && data.usage.completion_tokens) {
|
||||
tokensReceived = data.usage.completion_tokens;
|
||||
} else {
|
||||
tokensReceived += Math.ceil(content.length / 4);
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += value;
|
||||
|
||||
let lines = buffer.split("\n");
|
||||
buffer = lines.pop(); // Retain any incomplete line in the buffer
|
||||
|
||||
lines.forEach((line) => {
|
||||
if (line.length === 0 || line.startsWith(":")) return;
|
||||
if (line === "data: [DONE]") {
|
||||
return;
|
||||
}
|
||||
|
||||
if (line.startsWith("data: ")) {
|
||||
try {
|
||||
const eventData = JSON.parse(line.substring(6));
|
||||
|
||||
// Handle different event types
|
||||
switch (eventData.type) {
|
||||
case "reasoning":
|
||||
if (eventData.content) {
|
||||
const chatStore = Alpine.store("chat");
|
||||
// Insert reasoning before assistant message if it exists
|
||||
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
|
||||
chatStore.history.splice(lastAssistantMessageIndex, 0, {
|
||||
role: "reasoning",
|
||||
content: eventData.content,
|
||||
html: DOMPurify.sanitize(marked.parse(eventData.content)),
|
||||
image: [],
|
||||
audio: [],
|
||||
expanded: false // Reasoning is always collapsed
|
||||
});
|
||||
lastAssistantMessageIndex++; // Adjust index since we inserted
|
||||
// Scroll smoothly after adding reasoning
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 100);
|
||||
} else {
|
||||
// No assistant message yet, just add normally
|
||||
chatStore.add("reasoning", eventData.content);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case "tool_call":
|
||||
if (eventData.name) {
|
||||
// Store as JSON for better formatting
|
||||
const toolCallData = {
|
||||
name: eventData.name,
|
||||
arguments: eventData.arguments || {},
|
||||
reasoning: eventData.reasoning || ""
|
||||
};
|
||||
Alpine.store("chat").add("tool_call", JSON.stringify(toolCallData, null, 2));
|
||||
// Scroll smoothly after adding tool call
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
break;
|
||||
|
||||
case "tool_result":
|
||||
if (eventData.name) {
|
||||
// Store as JSON for better formatting
|
||||
const toolResultData = {
|
||||
name: eventData.name,
|
||||
result: eventData.result || ""
|
||||
};
|
||||
Alpine.store("chat").add("tool_result", JSON.stringify(toolResultData, null, 2));
|
||||
// Scroll smoothly after adding tool result
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
break;
|
||||
|
||||
case "status":
|
||||
// Status messages can be logged but not necessarily displayed
|
||||
console.log("[MCP Status]", eventData.message);
|
||||
break;
|
||||
|
||||
case "assistant":
|
||||
if (eventData.content) {
|
||||
assistantContent += eventData.content;
|
||||
const contentChunk = eventData.content;
|
||||
|
||||
// Count tokens for rate calculation
|
||||
tokensReceived += Math.ceil(contentChunk.length / 4);
|
||||
updateTokensPerSecond();
|
||||
|
||||
// Check for thinking tags in the chunk (incremental detection)
|
||||
if (contentChunk.includes("<thinking>") || contentChunk.includes("<think>")) {
|
||||
isThinking = true;
|
||||
thinkingContent = "";
|
||||
lastThinkingMessageIndex = -1;
|
||||
}
|
||||
|
||||
if (contentChunk.includes("</thinking>") || contentChunk.includes("</think>")) {
|
||||
isThinking = false;
|
||||
// When closing tag is detected, process the accumulated thinking content
|
||||
if (thinkingContent.trim()) {
|
||||
// Extract just the thinking part from the accumulated content
|
||||
const thinkingMatch = thinkingContent.match(/<(?:thinking|redacted_reasoning)>(.*?)<\/(?:thinking|redacted_reasoning)>/s);
|
||||
if (thinkingMatch && thinkingMatch[1]) {
|
||||
const extractedThinking = thinkingMatch[1];
|
||||
const chatStore = Alpine.store("chat");
|
||||
const isMCPMode = chatStore.mcpMode || false;
|
||||
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
|
||||
if (lastThinkingMessageIndex === -1) {
|
||||
// Insert thinking before the last assistant message if it exists
|
||||
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
|
||||
// Insert before assistant message
|
||||
chatStore.history.splice(lastAssistantMessageIndex, 0, {
|
||||
role: "thinking",
|
||||
content: extractedThinking,
|
||||
html: DOMPurify.sanitize(marked.parse(extractedThinking)),
|
||||
image: [],
|
||||
audio: [],
|
||||
expanded: shouldExpand
|
||||
});
|
||||
lastThinkingMessageIndex = lastAssistantMessageIndex;
|
||||
lastAssistantMessageIndex++; // Adjust index since we inserted
|
||||
} else {
|
||||
// No assistant message yet, just add normally
|
||||
chatStore.add("thinking", extractedThinking);
|
||||
lastThinkingMessageIndex = chatStore.history.length - 1;
|
||||
}
|
||||
} else {
|
||||
// Update existing thinking message
|
||||
const lastMessage = chatStore.history[lastThinkingMessageIndex];
|
||||
if (lastMessage && lastMessage.role === "thinking") {
|
||||
lastMessage.content = extractedThinking;
|
||||
lastMessage.html = DOMPurify.sanitize(marked.parse(extractedThinking));
|
||||
}
|
||||
}
|
||||
// Scroll when thinking is finalized in non-MCP mode
|
||||
if (!isMCPMode) {
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 50);
|
||||
}
|
||||
}
|
||||
thinkingContent = "";
|
||||
}
|
||||
}
|
||||
|
||||
// Handle content based on thinking state
|
||||
if (isThinking) {
|
||||
thinkingContent += contentChunk;
|
||||
const chatStore = Alpine.store("chat");
|
||||
const isMCPMode = chatStore.mcpMode || false;
|
||||
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
|
||||
// Update the last thinking message or create a new one (incremental)
|
||||
if (lastThinkingMessageIndex === -1) {
|
||||
// Insert thinking before the last assistant message if it exists
|
||||
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
|
||||
// Insert before assistant message
|
||||
chatStore.history.splice(lastAssistantMessageIndex, 0, {
|
||||
role: "thinking",
|
||||
content: thinkingContent,
|
||||
html: DOMPurify.sanitize(marked.parse(thinkingContent)),
|
||||
image: [],
|
||||
audio: [],
|
||||
expanded: shouldExpand
|
||||
});
|
||||
lastThinkingMessageIndex = lastAssistantMessageIndex;
|
||||
lastAssistantMessageIndex++; // Adjust index since we inserted
|
||||
} else {
|
||||
// No assistant message yet, just add normally
|
||||
chatStore.add("thinking", thinkingContent);
|
||||
lastThinkingMessageIndex = chatStore.history.length - 1;
|
||||
}
|
||||
} else {
|
||||
// Update existing thinking message
|
||||
const lastMessage = chatStore.history[lastThinkingMessageIndex];
|
||||
if (lastMessage && lastMessage.role === "thinking") {
|
||||
lastMessage.content = thinkingContent;
|
||||
lastMessage.html = DOMPurify.sanitize(marked.parse(thinkingContent));
|
||||
}
|
||||
}
|
||||
// Scroll when thinking is updated in non-MCP mode (throttled)
|
||||
if (!isMCPMode) {
|
||||
const now = Date.now();
|
||||
if (now - lastThinkingScrollTime > THINKING_SCROLL_THROTTLE) {
|
||||
lastThinkingScrollTime = now;
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular assistant content - buffer it for batch processing
|
||||
assistantContentBuffer.push(contentChunk);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case "error":
|
||||
Alpine.store("chat").add(
|
||||
"assistant",
|
||||
`<span class='error'>MCP Error: ${eventData.message}</span>`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to parse MCP event:", line, error);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Efficiently update assistant message in batch
|
||||
if (assistantContentBuffer.length > 0) {
|
||||
const regularContent = assistantContentBuffer.join("");
|
||||
|
||||
// Process any thinking tags that might be in the accumulated content
|
||||
// This handles cases where tags are split across chunks
|
||||
const { regularContent: processedRegular, thinkingContent: processedThinking } = processThinkingTags(regularContent);
|
||||
|
||||
// Update or create assistant message with processed regular content
|
||||
if (lastAssistantMessageIndex === -1) {
|
||||
if (processedRegular && processedRegular.trim()) {
|
||||
Alpine.store("chat").add("assistant", processedRegular);
|
||||
lastAssistantMessageIndex = Alpine.store("chat").history.length - 1;
|
||||
}
|
||||
} else {
|
||||
const chatStore = Alpine.store("chat");
|
||||
const lastMessage = chatStore.history[lastAssistantMessageIndex];
|
||||
if (lastMessage && lastMessage.role === "assistant") {
|
||||
lastMessage.content = (lastMessage.content || "") + (processedRegular || "");
|
||||
lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content));
|
||||
}
|
||||
}
|
||||
|
||||
// Add any extracted thinking content from the processed buffer BEFORE assistant message
|
||||
if (processedThinking && processedThinking.trim()) {
|
||||
const chatStore = Alpine.store("chat");
|
||||
const isMCPMode = chatStore.mcpMode || false;
|
||||
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
|
||||
// Insert thinking before assistant message if it exists
|
||||
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
|
||||
chatStore.history.splice(lastAssistantMessageIndex, 0, {
|
||||
role: "thinking",
|
||||
content: processedThinking,
|
||||
html: DOMPurify.sanitize(marked.parse(processedThinking)),
|
||||
image: [],
|
||||
audio: [],
|
||||
expanded: shouldExpand
|
||||
});
|
||||
lastAssistantMessageIndex++; // Adjust index since we inserted
|
||||
} else {
|
||||
// No assistant message yet, just add normally
|
||||
chatStore.add("thinking", processedThinking);
|
||||
}
|
||||
}
|
||||
|
||||
assistantContentBuffer = [];
|
||||
}
|
||||
updateTokensPerSecond();
|
||||
}
|
||||
|
||||
// Final assistant content flush if any data remains
|
||||
if (assistantContentBuffer.length > 0) {
|
||||
const regularContent = assistantContentBuffer.join("");
|
||||
// Process any remaining thinking tags that might be in the buffer
|
||||
const { regularContent: processedRegular, thinkingContent: processedThinking } = processThinkingTags(regularContent);
|
||||
|
||||
// Process thinking tags using shared function
|
||||
const { regularContent, thinkingContent } = processThinkingTags(content);
|
||||
const chatStore = Alpine.store("chat");
|
||||
|
||||
// Add thinking content if present
|
||||
if (thinkingContent) {
|
||||
// First, add any extracted thinking content BEFORE assistant message
|
||||
if (processedThinking && processedThinking.trim()) {
|
||||
const isMCPMode = chatStore.mcpMode || false;
|
||||
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
|
||||
// Insert thinking before assistant message if it exists
|
||||
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
|
||||
chatStore.history.splice(lastAssistantMessageIndex, 0, {
|
||||
role: "thinking",
|
||||
content: processedThinking,
|
||||
html: DOMPurify.sanitize(marked.parse(processedThinking)),
|
||||
image: [],
|
||||
audio: [],
|
||||
expanded: shouldExpand
|
||||
});
|
||||
lastAssistantMessageIndex++; // Adjust index since we inserted
|
||||
} else {
|
||||
// No assistant message yet, just add normally
|
||||
chatStore.add("thinking", processedThinking);
|
||||
}
|
||||
}
|
||||
|
||||
// Then update or create assistant message
|
||||
if (lastAssistantMessageIndex !== -1) {
|
||||
const lastMessage = chatStore.history[lastAssistantMessageIndex];
|
||||
if (lastMessage && lastMessage.role === "assistant") {
|
||||
lastMessage.content = (lastMessage.content || "") + (processedRegular || "");
|
||||
lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content));
|
||||
}
|
||||
} else if (processedRegular && processedRegular.trim()) {
|
||||
chatStore.add("assistant", processedRegular);
|
||||
lastAssistantMessageIndex = chatStore.history.length - 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Final thinking content flush if any data remains (from incremental detection)
|
||||
if (thinkingContent.trim() && lastThinkingMessageIndex === -1) {
|
||||
// Extract thinking content if tags are present
|
||||
const thinkingMatch = thinkingContent.match(/<(?:thinking|redacted_reasoning)>(.*?)<\/(?:thinking|redacted_reasoning)>/s);
|
||||
if (thinkingMatch && thinkingMatch[1]) {
|
||||
const chatStore = Alpine.store("chat");
|
||||
const isMCPMode = chatStore.mcpMode || false;
|
||||
const shouldExpand = !isMCPMode; // Expanded in non-MCP mode, collapsed in MCP mode
|
||||
// Insert thinking before assistant message if it exists
|
||||
if (lastAssistantMessageIndex >= 0 && chatStore.history[lastAssistantMessageIndex]?.role === "assistant") {
|
||||
chatStore.history.splice(lastAssistantMessageIndex, 0, {
|
||||
role: "thinking",
|
||||
content: thinkingMatch[1],
|
||||
html: DOMPurify.sanitize(marked.parse(thinkingMatch[1])),
|
||||
image: [],
|
||||
audio: [],
|
||||
expanded: shouldExpand
|
||||
});
|
||||
} else {
|
||||
// No assistant message yet, just add normally
|
||||
chatStore.add("thinking", thinkingMatch[1]);
|
||||
}
|
||||
} else {
|
||||
Alpine.store("chat").add("thinking", thinkingContent);
|
||||
}
|
||||
|
||||
// Add regular content if present
|
||||
if (regularContent) {
|
||||
Alpine.store("chat").add("assistant", regularContent);
|
||||
}
|
||||
}
|
||||
|
||||
// Highlight all code blocks
|
||||
// Final pass: process the entire assistantContent to catch any missed thinking tags
|
||||
// This ensures we don't miss tags that were split across chunks
|
||||
if (assistantContent.trim()) {
|
||||
const { regularContent: finalRegular, thinkingContent: finalThinking } = processThinkingTags(assistantContent);
|
||||
|
||||
// Update assistant message with final processed content (without thinking tags)
|
||||
if (finalRegular && finalRegular.trim()) {
|
||||
if (lastAssistantMessageIndex !== -1) {
|
||||
const chatStore = Alpine.store("chat");
|
||||
const lastMessage = chatStore.history[lastAssistantMessageIndex];
|
||||
if (lastMessage && lastMessage.role === "assistant") {
|
||||
lastMessage.content = finalRegular;
|
||||
lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content));
|
||||
}
|
||||
} else {
|
||||
Alpine.store("chat").add("assistant", finalRegular);
|
||||
}
|
||||
}
|
||||
|
||||
// Add any extracted thinking content (only if not already added)
|
||||
if (finalThinking && finalThinking.trim()) {
|
||||
const hasThinking = Alpine.store("chat").history.some(msg =>
|
||||
msg.role === "thinking" && msg.content.trim() === finalThinking.trim()
|
||||
);
|
||||
if (!hasThinking) {
|
||||
Alpine.store("chat").add("thinking", finalThinking);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Highlight all code blocks once at the end
|
||||
hljs.highlightAll();
|
||||
} catch (error) {
|
||||
// Don't show error if request was aborted by user
|
||||
if (error.name !== 'AbortError' || currentAbortController) {
|
||||
if (error.name !== 'AbortError' || !currentAbortController) {
|
||||
Alpine.store("chat").add(
|
||||
"assistant",
|
||||
`<span class='error'>Error: Failed to parse MCP response</span>`,
|
||||
`<span class='error'>Error: Failed to process MCP stream</span>`,
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
// Perform any cleanup if necessary
|
||||
if (reader) {
|
||||
reader.releaseLock();
|
||||
}
|
||||
currentReader = null;
|
||||
currentAbortController = null;
|
||||
}
|
||||
} else {
|
||||
@@ -539,6 +1051,8 @@ async function promptGPT(systemPrompt, input) {
|
||||
let thinkingContent = "";
|
||||
let isThinking = false;
|
||||
let lastThinkingMessageIndex = -1;
|
||||
let lastThinkingScrollTime = 0;
|
||||
const THINKING_SCROLL_THROTTLE = 200; // Throttle scrolling to every 200ms
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
@@ -606,6 +1120,23 @@ async function promptGPT(systemPrompt, input) {
|
||||
lastMessage.html = DOMPurify.sanitize(marked.parse(thinkingContent));
|
||||
}
|
||||
}
|
||||
// Scroll when thinking is updated (throttled)
|
||||
const now = Date.now();
|
||||
if (now - lastThinkingScrollTime > THINKING_SCROLL_THROTTLE) {
|
||||
lastThinkingScrollTime = now;
|
||||
setTimeout(() => {
|
||||
// Scroll main chat container
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
// Scroll thinking box to bottom if it's expanded and scrollable
|
||||
scrollThinkingBoxToBottom();
|
||||
}, 100);
|
||||
}
|
||||
} else {
|
||||
contentBuffer.push(token);
|
||||
}
|
||||
@@ -620,6 +1151,16 @@ async function promptGPT(systemPrompt, input) {
|
||||
if (contentBuffer.length > 0) {
|
||||
addToChat(contentBuffer.join(""));
|
||||
contentBuffer = [];
|
||||
// Scroll when assistant content is updated (this will also show thinking messages above)
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 50);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -654,8 +1195,17 @@ async function promptGPT(systemPrompt, input) {
|
||||
// Remove class "loader" from the element with "loader" id
|
||||
toggleLoader(false);
|
||||
|
||||
// scroll to the bottom of the chat
|
||||
document.getElementById('messages').scrollIntoView(false)
|
||||
// scroll to the bottom of the chat consistently
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}, 100);
|
||||
|
||||
// set focus to the input
|
||||
document.getElementById("input").focus();
|
||||
}
|
||||
@@ -748,8 +1298,8 @@ document.addEventListener("alpine:init", () => {
|
||||
},
|
||||
add(role, content, image, audio) {
|
||||
const N = this.history.length - 1;
|
||||
// For thinking messages, always create a new message
|
||||
if (role === "thinking") {
|
||||
// For thinking and reasoning messages, always create a new message
|
||||
if (role === "thinking" || role === "reasoning") {
|
||||
let c = "";
|
||||
const lines = content.split("\n");
|
||||
lines.forEach((line) => {
|
||||
@@ -784,7 +1334,21 @@ document.addEventListener("alpine:init", () => {
|
||||
audio: audio || []
|
||||
});
|
||||
}
|
||||
document.getElementById('messages').scrollIntoView(false);
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
// Also scroll thinking box if it's a thinking/reasoning message
|
||||
if (role === "thinking" || role === "reasoning") {
|
||||
setTimeout(() => {
|
||||
if (typeof window.scrollThinkingBoxToBottom === 'function') {
|
||||
window.scrollThinkingBoxToBottom();
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
const parser = new DOMParser();
|
||||
const html = parser.parseFromString(
|
||||
this.history[this.history.length - 1].html,
|
||||
@@ -812,3 +1376,56 @@ document.addEventListener("alpine:init", () => {
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Check for message from index page on load
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// Wait for Alpine to be ready
|
||||
setTimeout(() => {
|
||||
const chatData = localStorage.getItem('localai_index_chat_data');
|
||||
if (chatData) {
|
||||
try {
|
||||
const data = JSON.parse(chatData);
|
||||
const input = document.getElementById('input');
|
||||
|
||||
if (input && data.message) {
|
||||
// Set the message in the input
|
||||
input.value = data.message;
|
||||
|
||||
// Process files if any
|
||||
if (data.imageFiles && data.imageFiles.length > 0) {
|
||||
data.imageFiles.forEach(file => {
|
||||
images.push(file.data);
|
||||
});
|
||||
}
|
||||
|
||||
if (data.audioFiles && data.audioFiles.length > 0) {
|
||||
data.audioFiles.forEach(file => {
|
||||
audios.push(file.data);
|
||||
});
|
||||
}
|
||||
|
||||
if (data.textFiles && data.textFiles.length > 0) {
|
||||
data.textFiles.forEach(file => {
|
||||
fileContents.push({ name: file.name, content: file.data });
|
||||
currentFileNames.push(file.name);
|
||||
});
|
||||
}
|
||||
|
||||
// Clear localStorage
|
||||
localStorage.removeItem('localai_index_chat_data');
|
||||
|
||||
// Auto-submit after a short delay to ensure everything is ready
|
||||
setTimeout(() => {
|
||||
if (input.value.trim()) {
|
||||
processAndSendMessage(input.value);
|
||||
}
|
||||
}, 500);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error processing chat data from index:', error);
|
||||
localStorage.removeItem('localai_index_chat_data');
|
||||
}
|
||||
}
|
||||
}, 300);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// BaseURL returns the base URL for the given HTTP request context.
|
||||
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
|
||||
// The returned URL is guaranteed to end with `/`.
|
||||
// The method should be used in conjunction with the StripPathPrefix middleware.
|
||||
func BaseURL(c *fiber.Ctx) string {
|
||||
path := c.Path()
|
||||
origPath := c.OriginalURL()
|
||||
|
||||
if path != origPath && strings.HasSuffix(origPath, path) {
|
||||
pathPrefix := origPath[:len(origPath)-len(path)+1]
|
||||
|
||||
return c.BaseURL() + pathPrefix
|
||||
}
|
||||
|
||||
return c.BaseURL() + "/"
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBaseURL(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
prefix string
|
||||
expectURL string
|
||||
}{
|
||||
{
|
||||
name: "without prefix",
|
||||
prefix: "/",
|
||||
expectURL: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "with prefix",
|
||||
prefix: "/myprefix/",
|
||||
expectURL: "http://example.com/myprefix/",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
actualURL := ""
|
||||
|
||||
app.Get(tc.prefix+"hello/world", func(c *fiber.Ctx) error {
|
||||
if tc.prefix != "/" {
|
||||
c.Path("/hello/world")
|
||||
}
|
||||
actualURL = BaseURL(c)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", tc.prefix+"hello/world", nil)
|
||||
resp, err := app.Test(req, -1)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode, "response status code")
|
||||
require.Equal(t, tc.expectURL, actualURL, "base URL")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,9 +9,9 @@
|
||||
|
||||
<div class="container mx-auto px-4 py-8 flex-grow">
|
||||
<!-- Error Section -->
|
||||
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-2xl shadow-2xl shadow-[#38BDF8]/10 p-8 mb-10">
|
||||
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl p-8 mb-10">
|
||||
<div class="max-w-4xl mx-auto text-center">
|
||||
<div class="mb-6 text-6xl text-[#38BDF8] animate-pulse">
|
||||
<div class="mb-6 text-6xl text-[#38BDF8]">
|
||||
<i class="fas fa-exclamation-circle"></i>
|
||||
</div>
|
||||
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
|
||||
@@ -22,23 +22,21 @@
|
||||
<p class="text-xl text-[#94A3B8] mb-6">The page you're looking for doesn't exist or has been moved</p>
|
||||
<div class="flex flex-wrap justify-center gap-4">
|
||||
<a href="./"
|
||||
class="group flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(56,189,248,0.4)]">
|
||||
class="inline-flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition-colors">
|
||||
<i class="fas fa-home mr-2"></i>
|
||||
<span>Return Home</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</a>
|
||||
<a href="browse"
|
||||
class="group flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(139,92,246,0.4)]">
|
||||
class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition-colors">
|
||||
<i class="fas fa-images mr-2"></i>
|
||||
<span>Browse Gallery</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Additional Information -->
|
||||
<div class="bg-[#1E293B]/80 border border-[#1E293B] rounded-xl p-8 shadow-lg backdrop-blur-sm">
|
||||
<div class="bg-[#1E293B] border border-[#1E293B] rounded-xl p-8">
|
||||
<div class="text-center max-w-3xl mx-auto">
|
||||
<div class="inline-flex items-center justify-center w-16 h-16 rounded-full bg-yellow-500/10 border border-yellow-500/20 mb-4">
|
||||
<i class="text-yellow-400 text-2xl fa-solid fa-triangle-exclamation"></i>
|
||||
|
||||
@@ -11,21 +11,21 @@
|
||||
<div class="fixed top-20 right-4 z-50 space-y-2" style="max-width: 400px;">
|
||||
<template x-for="notification in notifications" :key="notification.id">
|
||||
<div x-show="true"
|
||||
x-transition:enter="transform ease-out duration-300 transition"
|
||||
x-transition:enter-start="translate-x-full opacity-0"
|
||||
x-transition:enter-end="translate-x-0 opacity-100"
|
||||
x-transition:leave="transform ease-in duration-200 transition"
|
||||
x-transition:leave-start="translate-x-0 opacity-100"
|
||||
x-transition:leave-end="translate-x-full opacity-0"
|
||||
x-transition:enter="transition ease-out duration-200"
|
||||
x-transition:enter-start="opacity-0"
|
||||
x-transition:enter-end="opacity-100"
|
||||
x-transition:leave="transition ease-in duration-150"
|
||||
x-transition:leave-start="opacity-100"
|
||||
x-transition:leave-end="opacity-0"
|
||||
:class="notification.type === 'error' ? 'bg-red-500' : 'bg-green-500'"
|
||||
class="rounded-lg shadow-xl p-4 text-white flex items-start space-x-3">
|
||||
class="rounded-lg p-4 text-white flex items-start space-x-3">
|
||||
<div class="flex-shrink-0">
|
||||
<i :class="notification.type === 'error' ? 'fas fa-exclamation-circle' : 'fas fa-check-circle'" class="text-xl"></i>
|
||||
</div>
|
||||
<div class="flex-1 min-w-0">
|
||||
<p class="text-sm font-medium break-words" x-text="notification.message"></p>
|
||||
</div>
|
||||
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:text-gray-200">
|
||||
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:opacity-80 transition-opacity">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
@@ -35,14 +35,8 @@
|
||||
<div class="container mx-auto px-4 py-8 flex-grow">
|
||||
|
||||
<!-- Hero Header -->
|
||||
<div class="relative bg-[#1E293B] border border-[#8B5CF6]/20 rounded-3xl shadow-2xl shadow-[#8B5CF6]/10 p-8 mb-12 overflow-hidden">
|
||||
<!-- Background Pattern -->
|
||||
<div class="absolute inset-0 opacity-10">
|
||||
<div class="absolute inset-0 bg-gradient-to-r from-[#8B5CF6]/20 to-[#38BDF8]/20"></div>
|
||||
<div class="absolute top-0 left-0 w-full h-full" style="background-image: radial-gradient(circle at 1px 1px, rgba(139,92,246,0.15) 1px, transparent 0); background-size: 20px 20px;"></div>
|
||||
</div>
|
||||
|
||||
<div class="relative max-w-5xl mx-auto text-center">
|
||||
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8 mb-12">
|
||||
<div class="max-w-5xl mx-auto text-center">
|
||||
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
|
||||
<span class="bg-clip-text text-transparent bg-gradient-to-r from-[#8B5CF6] via-[#38BDF8] to-[#8B5CF6]">
|
||||
Backend Management
|
||||
@@ -52,13 +46,18 @@
|
||||
Discover and install AI backends to power your models
|
||||
</p>
|
||||
<div class="flex flex-wrap justify-center items-center gap-6 text-sm md:text-base">
|
||||
<div class="flex items-center bg-white/10 rounded-full px-4 py-2">
|
||||
<div class="w-2 h-2 bg-emerald-400 rounded-full mr-2 animate-pulse"></div>
|
||||
<div class="flex items-center bg-[#101827] rounded-lg px-4 py-2">
|
||||
<div class="w-2 h-2 bg-emerald-400 rounded-full mr-2"></div>
|
||||
<span class="font-semibold text-emerald-300" x-text="availableBackends"></span>
|
||||
<span class="text-gray-300 ml-1">backends available</span>
|
||||
<span class="text-[#94A3B8] ml-1">backends available</span>
|
||||
</div>
|
||||
<a href="/manage" class="flex items-center bg-[#101827] hover:bg-[#1E293B] rounded-lg px-4 py-2 transition-colors border border-[#8B5CF6]/30 hover:border-[#8B5CF6]/50">
|
||||
<div class="w-2 h-2 bg-cyan-400 rounded-full mr-2"></div>
|
||||
<span class="font-semibold text-cyan-300" x-text="installedBackends"></span>
|
||||
<span class="text-[#94A3B8] ml-1">installed</span>
|
||||
</a>
|
||||
<a href="https://localai.io/backends/" target="_blank"
|
||||
class="flex items-center bg-cyan-600/80 hover:bg-cyan-600 text-white px-4 py-2 rounded-full transition-all duration-300 hover:scale-105">
|
||||
class="inline-flex items-center bg-cyan-600 hover:bg-cyan-700 text-white px-4 py-2 rounded-lg transition-colors">
|
||||
<i class="fas fa-info-circle mr-2"></i>
|
||||
<span>Documentation</span>
|
||||
<i class="fas fa-external-link-alt ml-2 text-xs"></i>
|
||||
@@ -70,28 +69,26 @@
|
||||
{{template "views/partials/inprogress" .}}
|
||||
|
||||
<!-- Search and Filter Section -->
|
||||
<div class="relative bg-gradient-to-br from-gray-800/80 to-gray-900/80 rounded-2xl p-8 mb-8 shadow-xl border border-gray-700/50 backdrop-blur-sm">
|
||||
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-emerald-500/5 to-cyan-500/5"></div>
|
||||
|
||||
<div class="relative">
|
||||
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8 mb-8">
|
||||
<div>
|
||||
<!-- Search Input -->
|
||||
<div class="mb-8">
|
||||
<h3 class="text-xl font-semibold text-white mb-4 flex items-center">
|
||||
<i class="fas fa-search mr-3 text-emerald-400"></i>
|
||||
<h3 class="text-xl font-semibold text-[#E5E7EB] mb-4 flex items-center">
|
||||
<i class="fas fa-search mr-3 text-[#8B5CF6]"></i>
|
||||
Find Backend Components
|
||||
</h3>
|
||||
<div class="relative">
|
||||
<div class="absolute inset-y-0 start-0 flex items-center ps-4 pointer-events-none">
|
||||
<i class="fas fa-search text-gray-400"></i>
|
||||
<i class="fas fa-search text-[#94A3B8]"></i>
|
||||
</div>
|
||||
<input
|
||||
x-model="searchTerm"
|
||||
@input.debounce.500ms="fetchBackends()"
|
||||
class="w-full pl-12 pr-16 py-4 text-base font-normal text-gray-300 bg-gray-900/90 border border-gray-700/70 rounded-xl transition-all duration-300 focus:text-gray-200 focus:bg-gray-900 focus:border-emerald-500 focus:ring-2 focus:ring-emerald-500/50 focus:outline-none"
|
||||
class="w-full pl-12 pr-16 py-4 text-base font-normal text-[#E5E7EB] bg-[#101827] border border-[#1E293B] rounded-lg transition-colors focus:text-[#E5E7EB] focus:bg-[#101827] focus:border-[#8B5CF6] focus:ring-2 focus:ring-[#8B5CF6]/50 focus:outline-none"
|
||||
type="search"
|
||||
placeholder="Search backends by name, description or type...">
|
||||
<span class="absolute right-4 top-4" x-show="loading">
|
||||
<svg class="animate-spin h-6 w-6 text-emerald-500" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<svg class="animate-spin h-6 w-6 text-[#8B5CF6]" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||
</svg>
|
||||
@@ -107,28 +104,28 @@
|
||||
</h3>
|
||||
<div class="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-5 gap-3">
|
||||
<button @click="filterByTerm('llm')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-indigo-600/80 to-indigo-700/80 hover:from-indigo-600 hover:to-indigo-700 text-indigo-100 border border-indigo-500/30 hover:border-indigo-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-indigo-500/25">
|
||||
<i class="fas fa-brain mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-indigo-600/20 hover:bg-indigo-600/30 text-indigo-300 border border-indigo-500/30 transition-colors">
|
||||
<i class="fas fa-brain mr-2"></i>
|
||||
<span>LLM</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('diffusion')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-purple-600/80 to-purple-700/80 hover:from-purple-600 hover:to-purple-700 text-purple-100 border border-purple-500/30 hover:border-purple-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-purple-500/25">
|
||||
<i class="fas fa-image mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-purple-600/20 hover:bg-purple-600/30 text-purple-300 border border-purple-500/30 transition-colors">
|
||||
<i class="fas fa-image mr-2"></i>
|
||||
<span>Diffusion</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('tts')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-blue-600/80 to-blue-700/80 hover:from-blue-600 hover:to-blue-700 text-blue-100 border border-blue-500/30 hover:border-blue-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-blue-500/25">
|
||||
<i class="fas fa-microphone mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-blue-600/20 hover:bg-blue-600/30 text-blue-300 border border-blue-500/30 transition-colors">
|
||||
<i class="fas fa-microphone mr-2"></i>
|
||||
<span>TTS</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('whisper')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-green-600/80 to-green-700/80 hover:from-green-600 hover:to-green-700 text-green-100 border border-green-500/30 hover:border-green-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-green-500/25">
|
||||
<i class="fas fa-headphones mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-green-600/20 hover:bg-green-600/30 text-green-300 border border-green-500/30 transition-colors">
|
||||
<i class="fas fa-headphones mr-2"></i>
|
||||
<span>Whisper</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('object-detection')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-red-600/80 to-red-700/80 hover:from-red-600 hover:to-red-700 text-red-100 border border-red-500/30 hover:border-red-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-red-500/25">
|
||||
<i class="fas fa-eye mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-red-600/20 hover:bg-red-600/30 text-red-300 border border-red-500/30 transition-colors">
|
||||
<i class="fas fa-eye mr-2"></i>
|
||||
<span>Vision</span>
|
||||
</button>
|
||||
</div>
|
||||
@@ -171,10 +168,14 @@
|
||||
<tr class="hover:bg-[#38BDF8]/10 transition-colors duration-200">
|
||||
<!-- Icon -->
|
||||
<td class="px-6 py-4">
|
||||
<img :src="backend.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
|
||||
class="w-12 h-12 object-cover rounded-lg border border-[#38BDF8]/30"
|
||||
loading="lazy"
|
||||
:alt="backend.name">
|
||||
<div class="w-12 h-12 rounded-lg border border-[#38BDF8]/30 flex items-center justify-center bg-[#101827]">
|
||||
<img x-show="backend.icon"
|
||||
:src="backend.icon"
|
||||
class="w-full h-full object-cover rounded-lg"
|
||||
loading="lazy"
|
||||
:alt="backend.name">
|
||||
<i x-show="!backend.icon" class="fas fa-cog text-xl text-[#8B5CF6]"></i>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Backend Name -->
|
||||
@@ -301,9 +302,13 @@
|
||||
<!-- Modal Body -->
|
||||
<div class="p-4 md:p-5 space-y-4 overflow-y-auto flex-1 min-h-0">
|
||||
<div class="flex justify-center items-center">
|
||||
<img :src="selectedBackend?.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
|
||||
class="rounded-t-lg max-h-48 max-w-96 object-cover mt-3"
|
||||
loading="lazy">
|
||||
<div class="w-48 h-48 rounded-lg border border-gray-300 dark:border-gray-600 flex items-center justify-center bg-gray-100 dark:bg-gray-800 mt-3">
|
||||
<img x-show="selectedBackend?.icon"
|
||||
:src="selectedBackend?.icon"
|
||||
class="rounded-lg max-h-48 max-w-96 object-cover"
|
||||
loading="lazy">
|
||||
<i x-show="!selectedBackend?.icon" class="fas fa-cog text-6xl text-gray-400 dark:text-gray-500"></i>
|
||||
</div>
|
||||
</div>
|
||||
<div class="text-base leading-relaxed text-gray-500 dark:text-gray-400 break-words max-w-full markdown-content" x-html="renderMarkdown(selectedBackend?.description)"></div>
|
||||
<template x-if="selectedBackend?.tags && selectedBackend.tags.length > 0">
|
||||
@@ -353,8 +358,8 @@
|
||||
<button @click="goToPage(currentPage - 1)"
|
||||
:disabled="currentPage <= 1"
|
||||
:class="currentPage <= 1 ? 'opacity-50 cursor-not-allowed' : ''"
|
||||
class="group flex items-center justify-center h-12 w-12 bg-gray-700/80 hover:bg-emerald-600 text-gray-300 hover:text-white rounded-xl shadow-lg transition-all duration-300 ease-in-out transform hover:scale-110">
|
||||
<i class="fas fa-chevron-left group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center h-12 w-12 bg-[#1E293B] hover:bg-emerald-600 text-[#94A3B8] hover:text-white rounded-lg transition-colors">
|
||||
<i class="fas fa-chevron-left"></i>
|
||||
</button>
|
||||
<div class="text-gray-300 text-sm font-medium px-4">
|
||||
<span class="text-gray-400">Page</span>
|
||||
@@ -488,6 +493,7 @@ function backendsGallery() {
|
||||
currentPage: 1,
|
||||
totalPages: 1,
|
||||
availableBackends: 0,
|
||||
installedBackends: 0,
|
||||
selectedBackend: null,
|
||||
jobProgress: {},
|
||||
notifications: [],
|
||||
@@ -526,6 +532,7 @@ function backendsGallery() {
|
||||
this.currentPage = data.currentPage || 1;
|
||||
this.totalPages = data.totalPages || 1;
|
||||
this.availableBackends = data.availableBackends || 0;
|
||||
this.installedBackends = data.installedBackends || 0;
|
||||
} catch (error) {
|
||||
console.error('Error fetching backends:', error);
|
||||
} finally {
|
||||
|
||||
@@ -28,10 +28,10 @@ SOFTWARE.
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
{{template "views/partials/head" .}}
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.11.174/pdf.min.js"></script>
|
||||
<script src="static/assets/pdf.min.js"></script>
|
||||
<script>
|
||||
// Initialize PDF.js worker
|
||||
pdfjsLib.GlobalWorkerOptions.workerSrc = 'https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.11.174/pdf.worker.min.js';
|
||||
pdfjsLib.GlobalWorkerOptions.workerSrc = 'static/assets/pdf.worker.min.js';
|
||||
</script>
|
||||
<script>
|
||||
// Initialize Alpine store - must run before Alpine processes DOM
|
||||
@@ -111,14 +111,36 @@ SOFTWARE.
|
||||
},
|
||||
add(role, content, image, audio) {
|
||||
const N = this.history.length - 1;
|
||||
// For thinking messages, always create a new message
|
||||
if (role === "thinking") {
|
||||
// For thinking, reasoning, tool_call, and tool_result messages, always create a new message
|
||||
if (role === "thinking" || role === "reasoning" || role === "tool_call" || role === "tool_result") {
|
||||
let c = "";
|
||||
const lines = content.split("\n");
|
||||
lines.forEach((line) => {
|
||||
c += DOMPurify.sanitize(marked.parse(line));
|
||||
});
|
||||
this.history.push({ role, content, html: c, image, audio });
|
||||
if (role === "tool_call" || role === "tool_result") {
|
||||
// For tool calls and results, try to parse as JSON and format nicely
|
||||
try {
|
||||
const parsed = typeof content === 'string' ? JSON.parse(content) : content;
|
||||
// Format JSON with proper indentation
|
||||
const formatted = JSON.stringify(parsed, null, 2);
|
||||
c = DOMPurify.sanitize('<pre><code class="language-json">' + formatted + '</code></pre>');
|
||||
} catch (e) {
|
||||
// If not JSON, treat as markdown
|
||||
const lines = content.split("\n");
|
||||
lines.forEach((line) => {
|
||||
c += DOMPurify.sanitize(marked.parse(line));
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// For thinking and reasoning, format as markdown
|
||||
const lines = content.split("\n");
|
||||
lines.forEach((line) => {
|
||||
c += DOMPurify.sanitize(marked.parse(line));
|
||||
});
|
||||
}
|
||||
// Set expanded state: thinking is expanded by default in non-MCP mode, collapsed in MCP mode
|
||||
// Reasoning, tool_call, and tool_result are always collapsed by default
|
||||
const isMCPMode = this.mcpMode || false;
|
||||
const shouldExpand = (role === "thinking" && !isMCPMode) || false;
|
||||
this.history.push({ role, content, html: c, image, audio, expanded: shouldExpand });
|
||||
|
||||
}
|
||||
// For other messages, merge if same role
|
||||
else if (this.history.length && this.history[N].role === role) {
|
||||
@@ -147,7 +169,22 @@ SOFTWARE.
|
||||
audio: audio || []
|
||||
});
|
||||
}
|
||||
document.getElementById('messages').scrollIntoView(false);
|
||||
// Scroll to bottom consistently for all messages (use #chat as it's the scrollable container)
|
||||
setTimeout(() => {
|
||||
const chatContainer = document.getElementById('chat');
|
||||
if (chatContainer) {
|
||||
chatContainer.scrollTo({
|
||||
top: chatContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
// Also scroll thinking box if it's a thinking/reasoning message
|
||||
if (role === "thinking" || role === "reasoning") {
|
||||
if (typeof window.scrollThinkingBoxToBottom === 'function') {
|
||||
window.scrollThinkingBoxToBottom();
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
const parser = new DOMParser();
|
||||
const html = parser.parseFromString(
|
||||
this.history[this.history.length - 1].html,
|
||||
@@ -160,9 +197,33 @@ SOFTWARE.
|
||||
if (this.languages.includes(language)) return;
|
||||
const script = document.createElement("script");
|
||||
script.src = `https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/languages/${language}.min.js`;
|
||||
script.onload = () => {
|
||||
// Re-highlight after language script loads
|
||||
if (window.hljs) {
|
||||
const container = document.getElementById('messages');
|
||||
if (container) {
|
||||
container.querySelectorAll('pre code.language-json').forEach(block => {
|
||||
window.hljs.highlightElement(block);
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
document.head.appendChild(script);
|
||||
this.languages.push(language);
|
||||
});
|
||||
// Highlight code blocks immediately if hljs is available
|
||||
if (window.hljs) {
|
||||
setTimeout(() => {
|
||||
const container = document.getElementById('messages');
|
||||
if (container) {
|
||||
container.querySelectorAll('pre code.language-json').forEach(block => {
|
||||
if (!block.classList.contains('hljs')) {
|
||||
window.hljs.highlightElement(block);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
},
|
||||
messages() {
|
||||
return this.history.map((message) => ({
|
||||
@@ -484,9 +545,113 @@ SOFTWARE.
|
||||
<li>To send a text, markdown or PDF file, click the <i class="fa-solid fa-file text-[#38BDF8]"></i> icon.</li>
|
||||
</ul>
|
||||
</p>
|
||||
<div id="messages" class="max-w-3xl mx-auto">
|
||||
<template x-for="message in history">
|
||||
<div :class="message.role === 'user' ? 'flex items-start space-x-2 my-2 justify-end' : 'flex items-start space-x-2 my-2'">
|
||||
<div id="messages" class="max-w-3xl mx-auto space-y-2">
|
||||
<template x-for="(message, index) in history" :key="index">
|
||||
<div>
|
||||
<!-- Reasoning/Thinking messages appear first (before assistant) - collapsible in MCP mode -->
|
||||
<template x-if="message.role === 'reasoning' || message.role === 'thinking'">
|
||||
<div class="flex items-start space-x-2 mb-1">
|
||||
<div class="flex flex-col flex-1">
|
||||
<div class="p-2 flex-1 rounded-lg bg-[#38BDF8]/10 text-[#94A3B8] border border-[#38BDF8]/30">
|
||||
<button
|
||||
@click="message.expanded = !message.expanded"
|
||||
class="w-full flex items-center justify-between text-left hover:bg-[#38BDF8]/20 rounded p-2 transition-colors"
|
||||
>
|
||||
<div class="flex items-center space-x-2">
|
||||
<i :class="message.role === 'thinking' ? 'fa-solid fa-brain' : 'fa-solid fa-lightbulb'" class="text-[#38BDF8]"></i>
|
||||
<span class="text-xs font-semibold text-[#38BDF8]" x-text="message.role === 'thinking' ? 'Thinking' : 'Reasoning'"></span>
|
||||
<span class="text-xs text-[#94A3B8]" x-show="message.content && message.content.length > 0" x-text="'(' + Math.ceil(message.content.length / 100) + ' lines)'"></span>
|
||||
</div>
|
||||
<i
|
||||
class="fa-solid text-[#38BDF8] transition-transform text-xs"
|
||||
:class="message.expanded ? 'fa-chevron-up' : 'fa-chevron-down'"
|
||||
></i>
|
||||
</button>
|
||||
<div
|
||||
x-show="message.expanded"
|
||||
x-transition
|
||||
class="mt-2 pt-2 border-t border-[#38BDF8]/20"
|
||||
>
|
||||
<div
|
||||
class="text-[#E5E7EB] text-sm max-h-96 overflow-auto"
|
||||
x-html="message.html"
|
||||
data-thinking-box
|
||||
x-effect="if (message.expanded && message.html) { setTimeout(() => { if ($el.scrollHeight > $el.clientHeight) { $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' }); } }, 50); }"
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Tool calls (collapsible) -->
|
||||
<template x-if="message.role === 'tool_call'">
|
||||
<div class="flex items-start space-x-2 mb-1">
|
||||
<div class="flex flex-col flex-1">
|
||||
<div class="p-2 flex-1 rounded-lg bg-[#8B5CF6]/10 text-[#94A3B8] border border-[#8B5CF6]/30">
|
||||
<button
|
||||
@click="message.expanded = !message.expanded"
|
||||
class="w-full flex items-center justify-between text-left hover:bg-[#8B5CF6]/20 rounded p-2 transition-colors"
|
||||
>
|
||||
<div class="flex items-center space-x-2">
|
||||
<i class="fa-solid fa-wrench text-[#8B5CF6]"></i>
|
||||
<span class="text-xs font-semibold text-[#8B5CF6]">Tool Call</span>
|
||||
<span class="text-xs text-[#94A3B8]" x-text="getToolName(message.content)"></span>
|
||||
</div>
|
||||
<i
|
||||
class="fa-solid text-[#8B5CF6] transition-transform text-xs"
|
||||
:class="message.expanded ? 'fa-chevron-up' : 'fa-chevron-down'"
|
||||
></i>
|
||||
</button>
|
||||
<div
|
||||
x-show="message.expanded"
|
||||
x-transition
|
||||
class="mt-2 pt-2 border-t border-[#8B5CF6]/20"
|
||||
>
|
||||
<div class="text-[#E5E7EB] text-xs max-h-96 overflow-auto overflow-x-auto tool-call-content"
|
||||
x-html="message.html"
|
||||
x-effect="if (message.expanded && window.hljs) { setTimeout(() => { $el.querySelectorAll('pre code.language-json').forEach(block => { if (!block.classList.contains('hljs')) window.hljs.highlightElement(block); }); }, 50); }"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Tool results (collapsible) -->
|
||||
<template x-if="message.role === 'tool_result'">
|
||||
<div class="flex items-start space-x-2 mb-1">
|
||||
<div class="flex flex-col flex-1">
|
||||
<div class="p-2 flex-1 rounded-lg bg-[#10B981]/10 text-[#94A3B8] border border-[#10B981]/30">
|
||||
<button
|
||||
@click="message.expanded = !message.expanded"
|
||||
class="w-full flex items-center justify-between text-left hover:bg-[#10B981]/20 rounded p-2 transition-colors"
|
||||
>
|
||||
<div class="flex items-center space-x-2">
|
||||
<i class="fa-solid fa-check-circle text-[#10B981]"></i>
|
||||
<span class="text-xs font-semibold text-[#10B981]">Tool Result</span>
|
||||
<span class="text-xs text-[#94A3B8]" x-text="getToolName(message.content) || 'Success'"></span>
|
||||
</div>
|
||||
<i
|
||||
class="fa-solid text-[#10B981] transition-transform text-xs"
|
||||
:class="message.expanded ? 'fa-chevron-up' : 'fa-chevron-down'"
|
||||
></i>
|
||||
</button>
|
||||
<div
|
||||
x-show="message.expanded"
|
||||
x-transition
|
||||
class="mt-2 pt-2 border-t border-[#10B981]/20"
|
||||
>
|
||||
<div class="text-[#E5E7EB] text-xs max-h-96 overflow-auto overflow-x-auto tool-result-content"
|
||||
x-html="formatToolResult(message.content)"
|
||||
x-effect="if (message.expanded && window.hljs) { setTimeout(() => { $el.querySelectorAll('pre code.language-json').forEach(block => { if (!block.classList.contains('hljs')) window.hljs.highlightElement(block); }); }, 50); }"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- User and Assistant messages -->
|
||||
<div :class="message.role === 'user' ? 'flex items-start space-x-2 justify-end' : 'flex items-start space-x-2'">
|
||||
{{ if .Model }}
|
||||
{{ $galleryConfig:= index $allGalleryConfigs .Model}}
|
||||
<template x-if="message.role === 'user'">
|
||||
@@ -514,20 +679,7 @@ SOFTWARE.
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<template x-if="message.role === 'thinking'">
|
||||
<div class="flex items-center space-x-2 w-full">
|
||||
<div class="flex flex-col flex-1">
|
||||
<div class="p-3 flex-1 rounded-lg bg-[#38BDF8]/10 text-[#94A3B8] border border-[#38BDF8]/30">
|
||||
<div class="flex items-center space-x-2 mb-2">
|
||||
<i class="fa-solid fa-brain text-[#38BDF8]"></i>
|
||||
<span class="text-xs font-semibold text-[#38BDF8]">Thinking</span>
|
||||
</div>
|
||||
<div class="mt-1 text-[#E5E7EB]" x-html="message.html"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<template x-if="message.role != 'user' && message.role != 'thinking'">
|
||||
<template x-if="message.role != 'user' && message.role != 'thinking' && message.role != 'reasoning' && message.role != 'tool_call' && message.role != 'tool_result'">
|
||||
<div class="flex items-center space-x-2">
|
||||
{{ if $galleryConfig }}
|
||||
{{ if $galleryConfig.Icon }}<img src="{{$galleryConfig.Icon}}" class="rounded-lg mt-2 max-w-8 max-h-8 border border-[#38BDF8]/20">{{end}}
|
||||
@@ -566,6 +718,7 @@ SOFTWARE.
|
||||
:class="message.role === 'user' ? 'fa-user text-[#38BDF8]' : 'fa-robot text-[#8B5CF6]'"
|
||||
></i>
|
||||
{{ end }}
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
@@ -573,8 +726,26 @@ SOFTWARE.
|
||||
|
||||
|
||||
<!-- Chat Input -->
|
||||
<div class="p-4 border-t border-[#1E293B]" x-data="{ inputValue: '', shiftPressed: false, fileName: '' }">
|
||||
<div class="p-4 border-t border-[#1E293B]" x-data="{ inputValue: '', shiftPressed: false, attachedFiles: [] }">
|
||||
<form id="prompt" action="chat/{{.Model}}" method="get" @submit.prevent="submitPrompt" class="max-w-3xl mx-auto">
|
||||
<!-- Attachment Tags - Show above input when files are attached -->
|
||||
<div x-show="attachedFiles.length > 0" class="mb-3 flex flex-wrap gap-2 items-center">
|
||||
<template x-for="(file, index) in attachedFiles" :key="index">
|
||||
<div class="inline-flex items-center gap-2 px-3 py-1.5 rounded-lg text-sm bg-[#38BDF8]/20 border border-[#38BDF8]/40 text-[#E5E7EB]">
|
||||
<i :class="file.type === 'image' ? 'fa-solid fa-image' : file.type === 'audio' ? 'fa-solid fa-microphone' : 'fa-solid fa-file'" class="text-[#38BDF8]"></i>
|
||||
<span x-text="file.name" class="max-w-[200px] truncate"></span>
|
||||
<button
|
||||
type="button"
|
||||
@click="attachedFiles.splice(index, 1); removeFileFromInput(file.type, file.name)"
|
||||
class="ml-1 text-[#94A3B8] hover:text-[#E5E7EB] transition-colors"
|
||||
title="Remove attachment"
|
||||
>
|
||||
<i class="fa-solid fa-times text-xs"></i>
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- Token Usage and Context Window - Compact above input -->
|
||||
<div class="mb-3 flex items-center justify-between gap-4 text-xs">
|
||||
<!-- Token Usage -->
|
||||
@@ -626,20 +797,19 @@ SOFTWARE.
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<div class="relative w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl shadow-lg">
|
||||
<div class="relative w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl shadow-lg focus-within:ring-2 focus-within:ring-[#38BDF8]/50 focus-within:border-[#38BDF8] transition-all duration-200">
|
||||
<textarea
|
||||
id="input"
|
||||
name="input"
|
||||
x-model="inputValue"
|
||||
placeholder="Send a message..."
|
||||
class="p-3 pr-16 w-full bg-[#1E293B] text-[#E5E7EB] placeholder-[#94A3B8] focus:outline-none resize-none border-0 rounded-xl transition-colors duration-200 focus:ring-2 focus:ring-[#38BDF8]/50"
|
||||
class="p-3 pr-16 w-full bg-[#1E293B] text-[#E5E7EB] placeholder-[#94A3B8] focus:outline-none resize-none border-0 rounded-xl transition-colors duration-200"
|
||||
required
|
||||
@keydown.shift="shiftPressed = true"
|
||||
@keyup.shift="shiftPressed = false"
|
||||
@keydown.enter.prevent="if (!shiftPressed) { submitPrompt($event); }"
|
||||
rows="2"
|
||||
></textarea>
|
||||
<span x-text="fileName" id="fileName" class="absolute right-16 top-3 text-[#94A3B8] text-xs mr-2"></span>
|
||||
<button
|
||||
type="button"
|
||||
onclick="document.getElementById('input_image').click()"
|
||||
@@ -692,7 +862,7 @@ SOFTWARE.
|
||||
multiple
|
||||
accept="image/*"
|
||||
style="display: none;"
|
||||
@change="fileName = $event.target.files.length + ' image(s) selected'"
|
||||
@change="handleFileSelection($event, 'image')"
|
||||
/>
|
||||
<input
|
||||
id="input_audio"
|
||||
@@ -700,7 +870,7 @@ SOFTWARE.
|
||||
multiple
|
||||
accept="audio/*"
|
||||
style="display: none;"
|
||||
@change="fileName = $event.target.files.length + ' audio file(s) selected'"
|
||||
@change="handleFileSelection($event, 'audio')"
|
||||
/>
|
||||
<input
|
||||
id="input_file"
|
||||
@@ -708,7 +878,7 @@ SOFTWARE.
|
||||
multiple
|
||||
accept=".txt,.md,.pdf"
|
||||
style="display: none;"
|
||||
@change="fileName = $event.target.files.length + ' file(s) selected'"
|
||||
@change="handleFileSelection($event, 'file')"
|
||||
/>
|
||||
</div>
|
||||
</form>
|
||||
@@ -775,6 +945,83 @@ SOFTWARE.
|
||||
console.error('Failed to copy: ', err);
|
||||
});
|
||||
};
|
||||
|
||||
// Format tool result for better display
|
||||
window.formatToolResult = (content) => {
|
||||
if (!content) return '';
|
||||
try {
|
||||
// Try to parse as JSON
|
||||
const parsed = JSON.parse(content);
|
||||
|
||||
// If it has a 'result' field, try to parse that too
|
||||
if (parsed.result && typeof parsed.result === 'string') {
|
||||
try {
|
||||
const resultParsed = JSON.parse(parsed.result);
|
||||
parsed.result = resultParsed;
|
||||
} catch (e) {
|
||||
// Keep as string if not JSON
|
||||
}
|
||||
}
|
||||
|
||||
// Format the JSON nicely
|
||||
const formatted = JSON.stringify(parsed, null, 2);
|
||||
return DOMPurify.sanitize('<pre class="bg-[#101827] p-3 rounded border border-[#10B981]/20 overflow-x-auto"><code class="language-json">' + formatted + '</code></pre>');
|
||||
} catch (e) {
|
||||
// If not JSON, try to format as markdown or plain text
|
||||
try {
|
||||
// Check if it's a markdown code block
|
||||
if (content.includes('```')) {
|
||||
return DOMPurify.sanitize(marked.parse(content));
|
||||
}
|
||||
// Otherwise, try to parse as JSON one more time with error handling
|
||||
const lines = content.split('\n');
|
||||
let jsonStart = -1;
|
||||
let jsonEnd = -1;
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
if (lines[i].trim().startsWith('{') || lines[i].trim().startsWith('[')) {
|
||||
jsonStart = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (jsonStart >= 0) {
|
||||
for (let i = lines.length - 1; i >= jsonStart; i--) {
|
||||
if (lines[i].trim().endsWith('}') || lines[i].trim().endsWith(']')) {
|
||||
jsonEnd = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (jsonEnd >= jsonStart) {
|
||||
const jsonStr = lines.slice(jsonStart, jsonEnd + 1).join('\n');
|
||||
try {
|
||||
const parsed = JSON.parse(jsonStr);
|
||||
const formatted = JSON.stringify(parsed, null, 2);
|
||||
return DOMPurify.sanitize('<pre class="bg-[#101827] p-3 rounded border border-[#10B981]/20 overflow-x-auto"><code class="language-json">' + formatted + '</code></pre>');
|
||||
} catch (e2) {
|
||||
// Fall through to markdown
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fall back to markdown
|
||||
return DOMPurify.sanitize(marked.parse(content));
|
||||
} catch (e2) {
|
||||
// Last resort: plain text
|
||||
return DOMPurify.sanitize('<pre class="bg-[#101827] p-3 rounded border border-[#10B981]/20 overflow-x-auto text-xs">' + content.replace(/</g, '<').replace(/>/g, '>') + '</pre>');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Get tool name from content
|
||||
window.getToolName = (content) => {
|
||||
if (!content || typeof content !== 'string') return '';
|
||||
try {
|
||||
const parsed = JSON.parse(content);
|
||||
return parsed.name || '';
|
||||
} catch (e) {
|
||||
// Try to extract name from string
|
||||
const nameMatch = content.match(/"name"\s*:\s*"([^"]+)"/);
|
||||
return nameMatch ? nameMatch[1] : '';
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
// Context size is now initialized in the Alpine store initialization above
|
||||
@@ -904,6 +1151,76 @@ SOFTWARE.
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
/* Prevent JSON overflow in tool calls and results */
|
||||
.tool-call-content pre,
|
||||
.tool-result-content pre {
|
||||
overflow-x: auto;
|
||||
overflow-y: auto;
|
||||
max-width: 100%;
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
background: #101827 !important;
|
||||
border: 1px solid #1E293B;
|
||||
border-radius: 6px;
|
||||
padding: 12px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.tool-call-content code,
|
||||
.tool-result-content code {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
overflow-wrap: break-word;
|
||||
background: transparent !important;
|
||||
color: #E5E7EB;
|
||||
font-family: 'ui-monospace', 'Monaco', 'Consolas', monospace;
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
/* Dark theme syntax highlighting for JSON */
|
||||
.tool-call-content .hljs,
|
||||
.tool-result-content .hljs {
|
||||
background: #101827 !important;
|
||||
color: #E5E7EB !important;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-keyword,
|
||||
.tool-result-content .hljs-keyword {
|
||||
color: #8B5CF6 !important;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-string,
|
||||
.tool-result-content .hljs-string {
|
||||
color: #10B981 !important;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-number,
|
||||
.tool-result-content .hljs-number {
|
||||
color: #38BDF8 !important;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-literal,
|
||||
.tool-result-content .hljs-literal {
|
||||
color: #F59E0B !important;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-punctuation,
|
||||
.tool-result-content .hljs-punctuation {
|
||||
color: #94A3B8 !important;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-property,
|
||||
.tool-result-content .hljs-property {
|
||||
color: #38BDF8 !important;
|
||||
}
|
||||
|
||||
.tool-call-content .hljs-attr,
|
||||
.tool-result-content .hljs-attr {
|
||||
color: #8B5CF6 !important;
|
||||
}
|
||||
</style>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
|
||||
<div class="container mx-auto px-4 py-8 flex-grow">
|
||||
<!-- Error Section -->
|
||||
<div class="bg-[#1E293B] border border-red-500/20 rounded-2xl shadow-2xl shadow-red-500/10 p-8 mb-10">
|
||||
<div class="bg-[#1E293B] border border-red-500/20 rounded-xl p-8 mb-10">
|
||||
<div class="max-w-4xl mx-auto text-center">
|
||||
<div class="mb-6 text-6xl text-red-400 animate-pulse">
|
||||
<div class="mb-6 text-6xl text-red-400">
|
||||
<i class="fas fa-exclamation-circle"></i>
|
||||
</div>
|
||||
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
|
||||
@@ -22,23 +22,21 @@
|
||||
<p class="text-xl text-[#94A3B8] mb-6">{{if .ErrorMessage}}{{.ErrorMessage}}{{else}}An unexpected error occurred{{end}}</p>
|
||||
<div class="flex flex-wrap justify-center gap-4">
|
||||
<a href="./"
|
||||
class="group flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(56,189,248,0.4)]">
|
||||
class="inline-flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition-colors">
|
||||
<i class="fas fa-home mr-2"></i>
|
||||
<span>Return Home</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</a>
|
||||
<a href="browse"
|
||||
class="group flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-[0_0_20px_rgba(139,92,246,0.4)]">
|
||||
class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white font-semibold py-3 px-6 rounded-lg transition-colors">
|
||||
<i class="fas fa-images mr-2"></i>
|
||||
<span>Browse Gallery</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Additional Information -->
|
||||
<div class="bg-[#1E293B]/80 border border-[#1E293B] rounded-xl p-8 shadow-lg backdrop-blur-sm">
|
||||
<div class="bg-[#1E293B] border border-[#1E293B] rounded-xl p-8">
|
||||
<div class="text-center max-w-3xl mx-auto">
|
||||
<div class="inline-flex items-center justify-center w-16 h-16 rounded-full bg-yellow-500/10 border border-yellow-500/20 mb-4">
|
||||
<i class="text-yellow-400 text-2xl fa-solid fa-triangle-exclamation"></i>
|
||||
|
||||
@@ -23,12 +23,10 @@
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
||||
transition: transform 0.3s ease, box-shadow 0.3s ease;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
.network-card:hover {
|
||||
transform: translateY(-5px);
|
||||
box-shadow: 0 6px 10px rgba(0, 0, 0, 0.15);
|
||||
background-color: #374151;
|
||||
}
|
||||
.network-title {
|
||||
font-size: 24px;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@
|
||||
|
||||
<div class="container mx-auto px-4 py-8 flex-grow flex items-center justify-center">
|
||||
<!-- Auth Card -->
|
||||
<div class="max-w-md w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl overflow-hidden shadow-2xl shadow-[#38BDF8]/10">
|
||||
<div class="max-w-md w-full bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl overflow-hidden">
|
||||
<div class="animation-container">
|
||||
<div class="text-overlay">
|
||||
<img src="static/logo.png" alt="LocalAI Logo" class="h-32 drop-shadow-[0_0_15px_rgba(56,189,248,0.3)]">
|
||||
@@ -47,11 +47,10 @@
|
||||
<div>
|
||||
<button
|
||||
type="submit"
|
||||
class="group w-full flex items-center justify-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-[1.02] hover:shadow-[0_0_20px_rgba(56,189,248,0.4)]"
|
||||
class="w-full flex items-center justify-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] font-semibold py-3 px-6 rounded-lg transition-colors"
|
||||
>
|
||||
<i class="fas fa-sign-in-alt mr-2"></i>
|
||||
<span>Login</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
569
core/http/views/manage.html
Normal file
569
core/http/views/manage.html
Normal file
@@ -0,0 +1,569 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
{{template "views/partials/head" .}}
|
||||
|
||||
<body class="bg-[#101827] text-[#E5E7EB]">
|
||||
<div class="flex flex-col min-h-screen" x-data="indexDashboard()">
|
||||
|
||||
{{template "views/partials/navbar" .}}
|
||||
|
||||
<!-- Notifications -->
|
||||
<div class="fixed top-20 right-4 z-50 space-y-2" style="max-width: 400px;">
|
||||
<template x-for="notification in notifications" :key="notification.id">
|
||||
<div x-show="true"
|
||||
x-transition:enter="transition ease-out duration-200"
|
||||
x-transition:enter-start="opacity-0"
|
||||
x-transition:enter-end="opacity-100"
|
||||
x-transition:leave="transition ease-in duration-150"
|
||||
x-transition:leave-start="opacity-100"
|
||||
x-transition:leave-end="opacity-0"
|
||||
:class="notification.type === 'error' ? 'bg-red-500' : 'bg-green-500'"
|
||||
class="rounded-lg p-4 text-white flex items-start space-x-3">
|
||||
<div class="flex-shrink-0">
|
||||
<i :class="notification.type === 'error' ? 'fas fa-exclamation-circle' : 'fas fa-check-circle'" class="text-xl"></i>
|
||||
</div>
|
||||
<div class="flex-1 min-w-0">
|
||||
<p class="text-sm font-medium break-words" x-text="notification.message"></p>
|
||||
</div>
|
||||
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:opacity-80 transition-opacity">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<div class="container mx-auto px-4 py-6 flex-grow">
|
||||
<!-- Header -->
|
||||
<div class="mb-6">
|
||||
<h1 class="text-2xl font-semibold text-[#E5E7EB] mb-1">
|
||||
Model & Backend Management
|
||||
</h1>
|
||||
<p class="text-sm text-[#94A3B8]">Manage your installed models and backends</p>
|
||||
</div>
|
||||
|
||||
<!-- Quick Actions -->
|
||||
<div class="flex flex-wrap gap-2 mb-6">
|
||||
<a href="browse"
|
||||
class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-images mr-1.5 text-[10px]"></i>
|
||||
<span>Model Gallery</span>
|
||||
</a>
|
||||
|
||||
<a href="/import-model"
|
||||
class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-plus mr-1.5 text-[10px]"></i>
|
||||
<span>Import Model</span>
|
||||
</a>
|
||||
|
||||
<button id="reload-models-btn"
|
||||
class="inline-flex items-center bg-orange-600 hover:bg-orange-700 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-sync-alt mr-1.5 text-[10px]"></i>
|
||||
<span>Update Models</span>
|
||||
</button>
|
||||
|
||||
<a href="/browse/backends"
|
||||
class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#8B5CF6]/20 text-[#E5E7EB] py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-cogs mr-1.5 text-[10px]"></i>
|
||||
<span>Backend Gallery</span>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<!-- Models Section -->
|
||||
<div class="models mt-8">
|
||||
{{template "views/partials/inprogress" .}}
|
||||
|
||||
{{ if eq (len .ModelsConfig) 0 }}
|
||||
<!-- No Models State -->
|
||||
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-lg p-8">
|
||||
<div class="text-center max-w-4xl mx-auto">
|
||||
<div class="inline-flex items-center justify-center w-12 h-12 rounded-full bg-yellow-500/10 border border-yellow-500/20 mb-4">
|
||||
<i class="text-yellow-400 text-xl fas fa-robot"></i>
|
||||
</div>
|
||||
<h2 class="text-2xl font-bold text-[#E5E7EB] mb-2">No models installed yet</h2>
|
||||
<p class="text-sm text-[#94A3B8] mb-6">Get started by installing a model from the gallery or importing it</p>
|
||||
|
||||
<div class="flex flex-wrap justify-center gap-2 mb-6">
|
||||
<a href="browse" class="inline-flex items-center bg-[#38BDF8] hover:bg-[#38BDF8]/90 text-[#101827] py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-images mr-1.5 text-[10px]"></i>
|
||||
Browse Model Gallery
|
||||
</a>
|
||||
<a href="/import-model" class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-upload mr-1.5 text-[10px]"></i>
|
||||
Import Model
|
||||
</a>
|
||||
<a href="https://localai.io/basics/getting_started/" target="_blank" class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#38BDF8]/20 text-[#E5E7EB] py-1.5 px-3 rounded text-xs font-medium transition-colors">
|
||||
<i class="fas fa-book mr-1.5 text-[10px]"></i>
|
||||
Documentation
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{{ if ne (len .Models) 0 }}
|
||||
<div class="mt-8 pt-6 border-t border-[#38BDF8]/20">
|
||||
<h3 class="text-lg font-semibold text-[#E5E7EB] mb-2 flex items-center">
|
||||
<i class="fas fa-file-alt mr-2 text-[#38BDF8] text-sm"></i>
|
||||
Detected Model Files
|
||||
</h3>
|
||||
<p class="text-xs text-[#94A3B8] mb-4">These models were found but don't have configuration files yet</p>
|
||||
<div class="flex flex-wrap gap-2 justify-center">
|
||||
{{ range .Models }}
|
||||
<div class="bg-[#101827] border border-[#38BDF8]/20 rounded px-2 py-1 flex items-center gap-2">
|
||||
<i class="fas fa-brain text-xs text-[#38BDF8]"></i>
|
||||
<span class="text-xs text-[#E5E7EB] font-medium">{{.}}</span>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
</div>
|
||||
{{ else }}
|
||||
<!-- Models Table -->
|
||||
{{ $modelsN := len .ModelsConfig}}
|
||||
{{ $modelsN = add $modelsN (len .Models)}}
|
||||
<div class="mb-6">
|
||||
<h2 class="text-2xl font-semibold text-[#E5E7EB] mb-1 flex items-center">
|
||||
<i class="fas fa-brain mr-2 text-[#38BDF8] text-sm"></i>
|
||||
Installed Models
|
||||
</h2>
|
||||
<p class="text-sm text-[#94A3B8] mb-4">
|
||||
<span class="text-[#38BDF8] font-medium">{{$modelsN}}</span> model{{if gt $modelsN 1}}s{{end}} ready to use
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="overflow-x-auto mb-8">
|
||||
<table class="w-full border-collapse">
|
||||
<thead>
|
||||
<tr class="border-b border-[#1E293B]">
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Name</th>
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Status</th>
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Backend</th>
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Use Cases</th>
|
||||
<th class="text-right p-2 text-xs font-semibold text-[#94A3B8]">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{$galleryConfig:=.GalleryConfig}}
|
||||
{{ $loadedModels := .LoadedModels }}
|
||||
|
||||
{{ range .ModelsConfig }}
|
||||
{{ $backendCfg := . }}
|
||||
{{ $cfg:= index $galleryConfig .Name}}
|
||||
<tr class="hover:bg-[#1E293B]/50 border-b border-[#1E293B] transition-colors">
|
||||
<!-- Name Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="relative flex-shrink-0">
|
||||
{{ if and $cfg $cfg.Icon }}
|
||||
<img src="{{$cfg.Icon}}" class="w-4 h-4 object-contain" alt="{{.Name}} icon">
|
||||
{{ else }}
|
||||
<i class="fas fa-brain text-xs text-[#38BDF8]"></i>
|
||||
{{ end }}
|
||||
{{ if index $loadedModels .Name }}
|
||||
<div class="absolute -top-0.5 -right-0.5 w-2 h-2 bg-green-500 rounded-full border border-[#1E293B]"></div>
|
||||
{{ end }}
|
||||
</div>
|
||||
<span class="text-xs text-[#E5E7EB] font-medium truncate">{{.Name}}</span>
|
||||
<a href="/models/edit/{{.Name}}"
|
||||
class="text-[#38BDF8]/60 hover:text-[#38BDF8] hover:bg-[#38BDF8]/10 rounded p-0.5 transition-colors ml-1 flex-shrink-0"
|
||||
title="Edit {{.Name}}">
|
||||
<i class="fas fa-edit text-[10px]"></i>
|
||||
</a>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Status Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{{ if index $loadedModels .Name }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-green-500/10 text-green-300">
|
||||
<i class="fas fa-circle text-[8px] mr-1"></i>Running
|
||||
</span>
|
||||
{{ end }}
|
||||
{{ if and $backendCfg (or (ne $backendCfg.MCP.Servers "") (ne $backendCfg.MCP.Stdio "")) }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#8B5CF6]/10 text-[#8B5CF6]">
|
||||
<i class="fas fa-plug text-[8px] mr-1"></i>MCP
|
||||
</span>
|
||||
{{ end }}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Backend Column -->
|
||||
<td class="p-2">
|
||||
{{ if .Backend }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#38BDF8]/10 text-[#38BDF8]">
|
||||
<i class="fas fa-cog text-[8px] mr-1"></i>{{.Backend}}
|
||||
</span>
|
||||
{{ else }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-yellow-500/10 text-yellow-300">
|
||||
<i class="fas fa-magic text-[8px] mr-1"></i>Auto
|
||||
</span>
|
||||
{{ end }}
|
||||
</td>
|
||||
|
||||
<!-- Use Cases Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{{ range .KnownUsecaseStrings }}
|
||||
{{ if eq . "FLAG_CHAT" }}
|
||||
<a href="chat/{{$backendCfg.Name}}" class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#38BDF8]/10 text-[#38BDF8] hover:bg-[#38BDF8]/20 transition-colors" title="Chat">
|
||||
<i class="fas fa-comment-alt text-[8px] mr-1"></i>Chat
|
||||
</a>
|
||||
{{ end }}
|
||||
{{ if eq . "FLAG_IMAGE" }}
|
||||
<a href="text2image/{{$backendCfg.Name}}" class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-green-500/10 text-green-300 hover:bg-green-500/20 transition-colors" title="Image">
|
||||
<i class="fas fa-image text-[8px] mr-1"></i>Image
|
||||
</a>
|
||||
{{ end }}
|
||||
{{ if eq . "FLAG_TTS" }}
|
||||
<a href="tts/{{$backendCfg.Name}}" class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#8B5CF6]/10 text-[#8B5CF6] hover:bg-[#8B5CF6]/20 transition-colors" title="TTS">
|
||||
<i class="fas fa-microphone text-[8px] mr-1"></i>TTS
|
||||
</a>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Actions Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex items-center justify-end gap-1">
|
||||
{{ if index $loadedModels .Name }}
|
||||
<button class="text-red-400/60 hover:text-red-400 hover:bg-red-500/10 rounded p-1 transition-colors"
|
||||
onclick="handleStopModel('{{.Name}}')"
|
||||
title="Stop {{.Name}}">
|
||||
<i class="fas fa-stop text-xs"></i>
|
||||
</button>
|
||||
{{ end }}
|
||||
<button class="text-red-400/60 hover:text-red-400 hover:bg-red-500/10 rounded p-1 transition-colors"
|
||||
onclick="handleDeleteModel('{{.Name}}')"
|
||||
title="Delete {{.Name}}">
|
||||
<i class="fas fa-trash-alt text-xs"></i>
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
{{ end }}
|
||||
|
||||
<!-- Models without config -->
|
||||
{{ range .Models }}
|
||||
<tr class="hover:bg-[#1E293B]/50 border-b border-[#1E293B] transition-colors">
|
||||
<td class="p-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-brain text-xs text-[#94A3B8]"></i>
|
||||
<span class="text-xs text-[#E5E7EB] font-medium truncate">{{.}}</span>
|
||||
</div>
|
||||
</td>
|
||||
<td class="p-2">
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-orange-500/10 text-orange-300">
|
||||
<i class="fas fa-exclamation-triangle text-[8px] mr-1"></i>No Config
|
||||
</span>
|
||||
</td>
|
||||
<td class="p-2">
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-yellow-500/10 text-yellow-300">
|
||||
<i class="fas fa-magic text-[8px] mr-1"></i>Auto
|
||||
</span>
|
||||
</td>
|
||||
<td class="p-2">
|
||||
<span class="text-xs text-[#94A3B8]">—</span>
|
||||
</td>
|
||||
<td class="p-2">
|
||||
<span class="text-xs text-[#94A3B8]">—</span>
|
||||
</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{{ end }}
|
||||
</div>
|
||||
|
||||
<!-- Backends Section -->
|
||||
<div class="mt-8">
|
||||
<div class="mb-6">
|
||||
<h2 class="text-2xl font-semibold text-[#E5E7EB] mb-1 flex items-center">
|
||||
<i class="fas fa-cogs mr-2 text-[#8B5CF6] text-sm"></i>
|
||||
Installed Backends
|
||||
</h2>
|
||||
<p class="text-sm text-[#94A3B8] mb-4">
|
||||
<span class="text-[#8B5CF6] font-medium">{{len .InstalledBackends}}</span> backend{{if gt (len .InstalledBackends) 1}}s{{end}} ready to use
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{{ if eq (len .InstalledBackends) 0 }}
|
||||
<!-- No backends state -->
|
||||
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-lg p-8">
|
||||
<div class="text-center max-w-4xl mx-auto">
|
||||
<div class="inline-flex items-center justify-center w-12 h-12 rounded-full bg-[#8B5CF6]/10 border border-[#8B5CF6]/20 mb-4">
|
||||
<i class="text-[#8B5CF6] text-xl fas fa-cogs"></i>
|
||||
</div>
|
||||
<h2 class="text-2xl font-bold text-[#E5E7EB] mb-2">No backends installed yet</h2>
|
||||
<p class="text-sm text-[#94A3B8] mb-6">Backends power your AI models. Install them from the backend gallery to get started</p>
|
||||
|
||||
<div class="flex flex-wrap justify-center gap-3">
|
||||
<a href="/browse/backends" class="inline-flex items-center bg-[#8B5CF6] hover:bg-[#8B5CF6]/90 text-white py-2 px-4 rounded-lg text-sm font-medium transition-colors">
|
||||
<i class="fas fa-cogs mr-2 text-xs"></i>
|
||||
Browse Backend Gallery
|
||||
</a>
|
||||
<a href="https://localai.io/backends/" target="_blank" class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#8B5CF6]/20 text-[#E5E7EB] py-2 px-4 rounded-lg text-sm font-medium transition-colors">
|
||||
<i class="fas fa-book mr-2 text-xs"></i>
|
||||
Documentation
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{{ else }}
|
||||
<!-- Backends Table -->
|
||||
<div class="overflow-x-auto mb-8">
|
||||
<table class="w-full border-collapse">
|
||||
<thead>
|
||||
<tr class="border-b border-[#1E293B]">
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Name</th>
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Type</th>
|
||||
<th class="text-left p-2 text-xs font-semibold text-[#94A3B8]">Metadata</th>
|
||||
<th class="text-right p-2 text-xs font-semibold text-[#94A3B8]">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{ range .InstalledBackends }}
|
||||
<tr class="hover:bg-[#1E293B]/50 border-b border-[#1E293B] transition-colors">
|
||||
<!-- Name Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-cog text-xs text-[#8B5CF6]"></i>
|
||||
<span class="text-xs text-[#E5E7EB] font-medium truncate">{{.Name}}</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Type Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{{ if .IsSystem }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-blue-500/10 text-blue-300">
|
||||
<i class="fas fa-shield-alt text-[8px] mr-1"></i>System
|
||||
</span>
|
||||
{{ else }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-green-500/10 text-green-300">
|
||||
<i class="fas fa-download text-[8px] mr-1"></i>User
|
||||
</span>
|
||||
{{ end }}
|
||||
{{ if .IsMeta }}
|
||||
<span class="inline-flex items-center px-1.5 py-0.5 rounded text-[10px] font-medium bg-[#8B5CF6]/10 text-[#8B5CF6]">
|
||||
<i class="fas fa-layer-group text-[8px] mr-1"></i>Meta
|
||||
</span>
|
||||
{{ end }}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Metadata Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex flex-col gap-1">
|
||||
{{ if and .Metadata .Metadata.Alias }}
|
||||
<span class="text-xs text-[#94A3B8]">
|
||||
<i class="fas fa-tag text-[8px] mr-1"></i>Alias: <span class="text-[#E5E7EB]">{{.Metadata.Alias}}</span>
|
||||
</span>
|
||||
{{ end }}
|
||||
{{ if and .Metadata .Metadata.MetaBackendFor }}
|
||||
<span class="text-xs text-[#94A3B8]">
|
||||
<i class="fas fa-link text-[8px] mr-1"></i>For: <span class="text-[#8B5CF6]">{{.Metadata.MetaBackendFor}}</span>
|
||||
</span>
|
||||
{{ end }}
|
||||
{{ if and .Metadata .Metadata.InstalledAt }}
|
||||
<span class="text-xs text-[#94A3B8]">
|
||||
<i class="fas fa-calendar text-[8px] mr-1"></i>{{.Metadata.InstalledAt}}
|
||||
</span>
|
||||
{{ end }}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Actions Column -->
|
||||
<td class="p-2">
|
||||
<div class="flex items-center justify-end gap-1">
|
||||
{{ if not .IsSystem }}
|
||||
<button
|
||||
@click="deleteBackend('{{.Name}}')"
|
||||
class="text-red-400/60 hover:text-red-400 hover:bg-red-500/10 rounded p-1 transition-colors"
|
||||
title="Delete {{.Name}}">
|
||||
<i class="fas fa-trash-alt text-xs"></i>
|
||||
</button>
|
||||
{{ else }}
|
||||
<span class="text-xs text-[#94A3B8]">—</span>
|
||||
{{ end }}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{{ end }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{{template "views/partials/footer" .}}
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Alpine.js component for index dashboard
|
||||
function indexDashboard() {
|
||||
return {
|
||||
notifications: [],
|
||||
|
||||
init() {
|
||||
// Initialize component
|
||||
},
|
||||
|
||||
addNotification(message, type = 'success') {
|
||||
const id = Date.now();
|
||||
this.notifications.push({ id, message, type });
|
||||
// Auto-dismiss after 5 seconds
|
||||
setTimeout(() => this.dismissNotification(id), 5000);
|
||||
},
|
||||
|
||||
dismissNotification(id) {
|
||||
this.notifications = this.notifications.filter(n => n.id !== id);
|
||||
},
|
||||
|
||||
async deleteBackend(backendName) {
|
||||
if (!confirm(`Are you sure you want to delete the backend "${backendName}"?`)) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/backends/system/delete/${encodeURIComponent(backendName)}`, {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
this.addNotification(`Backend "${backendName}" deleted successfully!`, 'success');
|
||||
// Reload page after short delay
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1500);
|
||||
} else {
|
||||
this.addNotification(`Failed to delete backend: ${data.error || 'Unknown error'}`, 'error');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting backend:', error);
|
||||
this.addNotification(`Failed to delete backend: ${error.message}`, 'error');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function handleStopModel(modelName) {
|
||||
if (!confirm('Are you sure you wish to stop this model?')) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('/backend/shutdown', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ model: modelName })
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
window.location.reload();
|
||||
} else {
|
||||
alert('Failed to stop model');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error stopping model:', error);
|
||||
alert('Failed to stop model');
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDeleteModel(modelName) {
|
||||
if (!confirm('Are you sure you wish to delete this model?')) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/models/delete/${encodeURIComponent(modelName)}`, {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
window.location.reload();
|
||||
} else {
|
||||
alert('Failed to delete model');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting model:', error);
|
||||
alert('Failed to delete model');
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reload models button
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const reloadBtn = document.getElementById('reload-models-btn');
|
||||
if (reloadBtn) {
|
||||
reloadBtn.addEventListener('click', function() {
|
||||
const button = this;
|
||||
const originalText = button.querySelector('span').textContent;
|
||||
const icon = button.querySelector('i');
|
||||
|
||||
// Show loading state
|
||||
button.disabled = true;
|
||||
button.querySelector('span').textContent = 'Updating...';
|
||||
icon.classList.add('fa-spin');
|
||||
|
||||
// Make the API call
|
||||
fetch('/models/reload', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success) {
|
||||
// Show success state briefly
|
||||
button.querySelector('span').textContent = 'Updated!';
|
||||
icon.classList.remove('fa-spin', 'fa-sync-alt');
|
||||
icon.classList.add('fa-check');
|
||||
|
||||
// Reload the page after a short delay
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
} else {
|
||||
// Show error state
|
||||
button.querySelector('span').textContent = 'Error!';
|
||||
icon.classList.remove('fa-spin');
|
||||
console.error('Failed to reload models:', data.error);
|
||||
|
||||
// Reset button after delay
|
||||
setTimeout(() => {
|
||||
button.disabled = false;
|
||||
button.querySelector('span').textContent = originalText;
|
||||
icon.classList.remove('fa-check');
|
||||
icon.classList.add('fa-sync-alt');
|
||||
}, 3000);
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
// Show error state
|
||||
button.querySelector('span').textContent = 'Error!';
|
||||
icon.classList.remove('fa-spin');
|
||||
console.error('Error reloading models:', error);
|
||||
|
||||
// Reset button after delay
|
||||
setTimeout(() => {
|
||||
button.disabled = false;
|
||||
button.querySelector('span').textContent = originalText;
|
||||
icon.classList.remove('fa-check');
|
||||
icon.classList.add('fa-sync-alt');
|
||||
}, 3000);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -10,45 +10,36 @@
|
||||
|
||||
<div class="container mx-auto px-4 py-8 flex-grow">
|
||||
<!-- Hero Header -->
|
||||
<div class="relative bg-[#1E293B] border border-[#8B5CF6]/20 rounded-3xl shadow-2xl shadow-[#8B5CF6]/10 p-8 mb-8 overflow-hidden">
|
||||
<!-- Background Pattern -->
|
||||
<div class="absolute inset-0 opacity-10">
|
||||
<div class="absolute inset-0 bg-gradient-to-r from-[#8B5CF6]/20 to-[#38BDF8]/20"></div>
|
||||
<div class="absolute top-0 left-0 w-full h-full" style="background-image: radial-gradient(circle at 1px 1px, rgba(139,92,246,0.15) 1px, transparent 0); background-size: 20px 20px;"></div>
|
||||
</div>
|
||||
|
||||
<div class="relative max-w-5xl mx-auto">
|
||||
<div class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8 mb-8">
|
||||
<div class="max-w-5xl mx-auto">
|
||||
<div class="flex flex-col md:flex-row md:items-center md:justify-between">
|
||||
<div class="mb-4 md:mb-0">
|
||||
<h1 class="text-3xl md:text-4xl font-bold text-white mb-2">
|
||||
<h1 class="text-3xl md:text-4xl font-bold text-[#E5E7EB] mb-2">
|
||||
<span class="bg-clip-text text-transparent bg-gradient-to-r from-violet-400 via-purple-400 to-fuchsia-400">
|
||||
{{if .ModelName}}Edit Model: {{.ModelName}}{{else}}Import New Model{{end}}
|
||||
</span>
|
||||
</h1>
|
||||
<p class="text-lg text-gray-300 font-light" x-text="isAdvancedMode ? 'Configure your model settings using YAML' : 'Import a model from URI with preferences'"></p>
|
||||
<p class="text-lg text-[#94A3B8] font-light" x-text="isAdvancedMode ? 'Configure your model settings using YAML' : 'Import a model from URI with preferences'"></p>
|
||||
</div>
|
||||
<div class="flex gap-3">
|
||||
<!-- Mode Toggle (only show when not in edit mode) -->
|
||||
<template x-if="!isEditMode">
|
||||
<button @click="toggleMode()"
|
||||
class="group relative inline-flex items-center bg-gradient-to-r from-gray-600 to-gray-700 hover:from-gray-700 hover:to-gray-800 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl">
|
||||
<i class="fas group-hover:animate-pulse" :class="isAdvancedMode ? 'fa-magic mr-2' : 'fa-code mr-2'"></i>
|
||||
class="inline-flex items-center bg-[#1E293B] hover:bg-[#1E293B]/80 border border-[#8B5CF6]/20 text-[#E5E7EB] py-3 px-6 rounded-lg font-semibold transition-colors">
|
||||
<i class="fas" :class="isAdvancedMode ? 'fa-magic mr-2' : 'fa-code mr-2'"></i>
|
||||
<span x-text="isAdvancedMode ? 'Simple Mode' : 'Advanced Mode'"></span>
|
||||
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
||||
</button>
|
||||
</template>
|
||||
<!-- Advanced Mode Buttons -->
|
||||
<template x-if="isAdvancedMode">
|
||||
<div class="flex gap-3">
|
||||
<button id="validateBtn" class="group relative inline-flex items-center bg-gradient-to-r from-blue-600 to-blue-700 hover:from-blue-700 hover:to-blue-800 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl hover:shadow-blue-500/25">
|
||||
<i class="fas fa-check mr-2 group-hover:animate-pulse"></i>
|
||||
<button id="validateBtn" class="inline-flex items-center bg-blue-600 hover:bg-blue-700 text-white py-3 px-6 rounded-lg font-semibold transition-colors">
|
||||
<i class="fas fa-check mr-2"></i>
|
||||
<span>Validate</span>
|
||||
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
||||
</button>
|
||||
<button id="saveBtn" class="group relative inline-flex items-center bg-gradient-to-r from-green-600 to-emerald-600 hover:from-green-700 hover:to-emerald-700 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl hover:shadow-green-500/25">
|
||||
<i class="fas fa-save mr-2 group-hover:animate-pulse"></i>
|
||||
<button id="saveBtn" class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-3 px-6 rounded-lg font-semibold transition-colors">
|
||||
<i class="fas fa-save mr-2"></i>
|
||||
<span>{{if .ModelName}}Update{{else}}Create{{end}}</span>
|
||||
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
@@ -57,10 +48,9 @@
|
||||
<button @click="submitImport()"
|
||||
:disabled="isSubmitting || !importUri.trim()"
|
||||
:class="(isSubmitting || !importUri.trim()) ? 'opacity-50 cursor-not-allowed' : ''"
|
||||
class="group relative inline-flex items-center bg-gradient-to-r from-green-600 to-emerald-600 hover:from-green-700 hover:to-emerald-700 text-white py-3 px-6 rounded-xl font-semibold transition-all duration-300 ease-in-out transform hover:scale-105 hover:shadow-xl hover:shadow-green-500/25">
|
||||
<i class="fas group-hover:animate-pulse" :class="isSubmitting ? 'fa-spinner fa-spin mr-2' : 'fa-upload mr-2'"></i>
|
||||
class="inline-flex items-center bg-green-600 hover:bg-green-700 text-white py-3 px-6 rounded-lg font-semibold transition-colors">
|
||||
<i class="fas" :class="isSubmitting ? 'fa-spinner fa-spin mr-2' : 'fa-upload mr-2'"></i>
|
||||
<span x-text="isSubmitting ? 'Importing...' : 'Import Model'"></span>
|
||||
<div class="absolute inset-0 rounded-xl bg-white/10 opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
||||
</button>
|
||||
</template>
|
||||
</div>
|
||||
@@ -73,15 +63,13 @@
|
||||
|
||||
<!-- Simple Import Mode -->
|
||||
<div x-show="!isAdvancedMode && !isEditMode"
|
||||
x-transition:enter="transition ease-out duration-300"
|
||||
x-transition:enter-start="opacity-0 transform translate-y-4"
|
||||
x-transition:enter-end="opacity-100 transform translate-y-0"
|
||||
class="relative bg-gradient-to-br from-gray-800/90 to-gray-900/90 border border-gray-700/50 rounded-2xl overflow-hidden shadow-xl backdrop-blur-sm p-8">
|
||||
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-green-500/5 to-emerald-500/5"></div>
|
||||
|
||||
<div class="relative space-y-6">
|
||||
<h2 class="text-2xl font-semibold text-white flex items-center gap-3 mb-6">
|
||||
<div class="w-10 h-10 rounded-lg bg-green-500/20 flex items-center justify-center">
|
||||
x-transition:enter="transition ease-out duration-200"
|
||||
x-transition:enter-start="opacity-0"
|
||||
x-transition:enter-end="opacity-100"
|
||||
class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl p-8">
|
||||
<div class="space-y-6">
|
||||
<h2 class="text-2xl font-semibold text-[#E5E7EB] flex items-center gap-3 mb-6">
|
||||
<div class="w-10 h-10 rounded-lg bg-green-500/10 flex items-center justify-center">
|
||||
<i class="fas fa-link text-green-400"></i>
|
||||
</div>
|
||||
Import from URI
|
||||
@@ -89,16 +77,16 @@
|
||||
|
||||
<!-- URI Input -->
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-gray-300 mb-2">
|
||||
<label class="block text-sm font-medium text-[#94A3B8] mb-2">
|
||||
<i class="fas fa-link mr-2"></i>Model URI
|
||||
</label>
|
||||
<input
|
||||
x-model="importUri"
|
||||
type="text"
|
||||
placeholder="https://example.com/model.gguf or file:///path/to/model.gguf"
|
||||
class="w-full px-4 py-3 bg-gray-900/90 border border-gray-700/70 rounded-xl text-gray-200 focus:border-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-all"
|
||||
class="w-full px-4 py-3 bg-[#101827] border border-[#1E293B] rounded-lg text-[#E5E7EB] focus:border-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-colors"
|
||||
:disabled="isSubmitting">
|
||||
<p class="mt-2 text-xs text-gray-400">
|
||||
<p class="mt-2 text-xs text-[#94A3B8]">
|
||||
Enter the URI or path to the model file you want to import
|
||||
</p>
|
||||
</div>
|
||||
@@ -130,6 +118,8 @@
|
||||
<option value="llama-cpp">llama-cpp</option>
|
||||
<option value="mlx">mlx</option>
|
||||
<option value="mlx-vlm">mlx-vlm</option>
|
||||
<option value="transformers">transformers</option>
|
||||
<option value="vllm">vllm</option>
|
||||
</select>
|
||||
<p class="mt-1 text-xs text-gray-400">
|
||||
Force a specific backend. Leave empty to auto-detect from URI.
|
||||
@@ -199,6 +189,39 @@
|
||||
Preferred MMProj quantizations (comma-separated). Examples: fp16, fp32. Leave empty to use default (fp16).
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Embeddings -->
|
||||
<div>
|
||||
<label class="flex items-center cursor-pointer">
|
||||
<input
|
||||
x-model="commonPreferences.embeddings"
|
||||
type="checkbox"
|
||||
class="w-5 h-5 rounded bg-gray-900/90 border-gray-700/70 text-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-all cursor-pointer"
|
||||
:disabled="isSubmitting">
|
||||
<span class="ml-3 text-sm font-medium text-gray-300">
|
||||
<i class="fas fa-vector-square mr-2"></i>Embeddings
|
||||
</span>
|
||||
</label>
|
||||
<p class="mt-1 ml-8 text-xs text-gray-400">
|
||||
Enable embeddings support for this model.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Type -->
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-gray-300 mb-2">
|
||||
<i class="fas fa-tag mr-2"></i>Model Type
|
||||
</label>
|
||||
<input
|
||||
x-model="commonPreferences.type"
|
||||
type="text"
|
||||
placeholder="AutoModelForCausalLM (for transformers backend)"
|
||||
class="w-full px-4 py-2 bg-gray-900/90 border border-gray-700/70 rounded-lg text-gray-200 focus:border-green-500 focus:ring-2 focus:ring-green-500/50 focus:outline-none transition-all"
|
||||
:disabled="isSubmitting">
|
||||
<p class="mt-1 text-xs text-gray-400">
|
||||
Model type for transformers backend. Examples: AutoModelForCausalLM, SentenceTransformer, Mamba, MusicgenForConditionalGeneration. Leave empty to use default (AutoModelForCausalLM).
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Custom Preferences -->
|
||||
@@ -248,25 +271,23 @@
|
||||
|
||||
<!-- Advanced YAML Editor Panel -->
|
||||
<div x-show="isAdvancedMode || isEditMode"
|
||||
x-transition:enter="transition ease-out duration-300"
|
||||
x-transition:enter-start="opacity-0 transform translate-y-4"
|
||||
x-transition:enter-end="opacity-100 transform translate-y-0"
|
||||
class="relative bg-gradient-to-br from-gray-800/90 to-gray-900/90 border border-gray-700/50 rounded-2xl overflow-hidden shadow-xl backdrop-blur-sm h-[calc(100vh-250px)]">
|
||||
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-fuchsia-500/5 to-purple-500/5"></div>
|
||||
|
||||
<div class="relative sticky top-0 bg-gray-800/95 border-b border-gray-700/50 p-6 flex items-center justify-between z-10 backdrop-blur-sm">
|
||||
<h2 class="text-xl font-semibold text-white flex items-center gap-3">
|
||||
<div class="w-8 h-8 rounded-lg bg-fuchsia-500/20 flex items-center justify-center">
|
||||
x-transition:enter="transition ease-out duration-200"
|
||||
x-transition:enter-start="opacity-0"
|
||||
x-transition:enter-end="opacity-100"
|
||||
class="bg-[#1E293B] border border-[#8B5CF6]/20 rounded-xl overflow-hidden h-[calc(100vh-250px)]">
|
||||
<div class="sticky top-0 bg-[#1E293B] border-b border-[#101827] p-6 flex items-center justify-between z-10">
|
||||
<h2 class="text-xl font-semibold text-[#E5E7EB] flex items-center gap-3">
|
||||
<div class="w-8 h-8 rounded-lg bg-fuchsia-500/10 flex items-center justify-center">
|
||||
<i class="fas fa-code text-fuchsia-400"></i>
|
||||
</div>
|
||||
YAML Configuration Editor
|
||||
</h2>
|
||||
<div class="flex items-center gap-3">
|
||||
<button id="formatYamlBtn" class="group text-gray-400 hover:text-gray-200 text-sm px-3 py-1.5 rounded-lg hover:bg-gray-700/50 transition-all duration-200">
|
||||
<i class="fas fa-indent mr-1.5 group-hover:animate-pulse"></i> Format
|
||||
<button id="formatYamlBtn" class="text-[#94A3B8] hover:text-[#E5E7EB] text-sm px-3 py-1.5 rounded-lg hover:bg-[#101827] transition-colors">
|
||||
<i class="fas fa-indent mr-1.5"></i> Format
|
||||
</button>
|
||||
<button id="copyYamlBtn" class="group text-gray-400 hover:text-gray-200 text-sm px-3 py-1.5 rounded-lg hover:bg-gray-700/50 transition-all duration-200">
|
||||
<i class="fas fa-copy mr-1.5 group-hover:animate-bounce"></i> Copy
|
||||
<button id="copyYamlBtn" class="text-[#94A3B8] hover:text-[#E5E7EB] text-sm px-3 py-1.5 rounded-lg hover:bg-[#101827] transition-colors">
|
||||
<i class="fas fa-copy mr-1.5"></i> Copy
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -280,13 +301,13 @@
|
||||
</div>
|
||||
|
||||
<!-- Include JS-YAML library -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/js-yaml/4.1.0/js-yaml.min.js"></script>
|
||||
<script src="static/assets/js-yaml.min.js"></script>
|
||||
|
||||
<!-- Include CodeMirror for syntax highlighting -->
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/codemirror.min.css">
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/codemirror.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/mode/yaml/yaml.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/codemirror/6.65.7/addon/display/autorefresh.min.js"></script>
|
||||
<link rel="stylesheet" href="static/assets/codemirror.min.css">
|
||||
<script src="static/assets/codemirror.min.js"></script>
|
||||
<script src="static/assets/yaml.min.js"></script>
|
||||
<script src="static/assets/autorefresh.min.js"></script>
|
||||
|
||||
<style>
|
||||
/* Enhanced CodeMirror styling */
|
||||
@@ -412,11 +433,9 @@
|
||||
|
||||
@keyframes slideInFromTop {
|
||||
from {
|
||||
transform: translateY(-20px);
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
transform: translateY(0);
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
@@ -458,7 +477,9 @@ function importModel() {
|
||||
name: '',
|
||||
description: '',
|
||||
quantizations: '',
|
||||
mmproj_quantizations: ''
|
||||
mmproj_quantizations: '',
|
||||
embeddings: false,
|
||||
type: ''
|
||||
},
|
||||
isSubmitting: false,
|
||||
currentJobId: null,
|
||||
@@ -527,6 +548,12 @@ function importModel() {
|
||||
if (this.commonPreferences.mmproj_quantizations && this.commonPreferences.mmproj_quantizations.trim()) {
|
||||
prefsObj.mmproj_quantizations = this.commonPreferences.mmproj_quantizations.trim();
|
||||
}
|
||||
if (this.commonPreferences.embeddings) {
|
||||
prefsObj.embeddings = 'true';
|
||||
}
|
||||
if (this.commonPreferences.type && this.commonPreferences.type.trim()) {
|
||||
prefsObj.type = this.commonPreferences.type.trim();
|
||||
}
|
||||
|
||||
// Add custom preferences (can override common ones)
|
||||
this.preferences.forEach(pref => {
|
||||
|
||||
@@ -10,22 +10,22 @@
|
||||
<!-- Notifications -->
|
||||
<div class="fixed top-20 right-4 z-50 space-y-2" style="max-width: 400px;">
|
||||
<template x-for="notification in notifications" :key="notification.id">
|
||||
<div x-show="true"
|
||||
x-transition:enter="transform ease-out duration-300 transition"
|
||||
x-transition:enter-start="translate-x-full opacity-0"
|
||||
x-transition:enter-end="translate-x-0 opacity-100"
|
||||
x-transition:leave="transform ease-in duration-200 transition"
|
||||
x-transition:leave-start="translate-x-0 opacity-100"
|
||||
x-transition:leave-end="translate-x-full opacity-0"
|
||||
<div x-show="true"
|
||||
x-transition:enter="transition ease-out duration-200"
|
||||
x-transition:enter-start="opacity-0"
|
||||
x-transition:enter-end="opacity-100"
|
||||
x-transition:leave="transition ease-in duration-150"
|
||||
x-transition:leave-start="opacity-100"
|
||||
x-transition:leave-end="opacity-0"
|
||||
:class="notification.type === 'error' ? 'bg-red-500' : 'bg-green-500'"
|
||||
class="rounded-lg shadow-xl p-4 text-white flex items-start space-x-3">
|
||||
class="rounded-lg p-4 text-white flex items-start space-x-3">
|
||||
<div class="flex-shrink-0">
|
||||
<i :class="notification.type === 'error' ? 'fas fa-exclamation-circle' : 'fas fa-check-circle'" class="text-xl"></i>
|
||||
</div>
|
||||
<div class="flex-1 min-w-0">
|
||||
<p class="text-sm font-medium break-words" x-text="notification.message"></p>
|
||||
</div>
|
||||
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:text-gray-200">
|
||||
<button @click="dismissNotification(notification.id)" class="flex-shrink-0 text-white hover:opacity-80 transition-opacity">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
@@ -35,35 +35,34 @@
|
||||
<div class="container mx-auto px-4 py-8 flex-grow">
|
||||
|
||||
<!-- Hero Header -->
|
||||
<div class="relative bg-[#1E293B] border border-[#38BDF8]/20 rounded-3xl shadow-2xl shadow-[#38BDF8]/10 p-8 mb-12 overflow-hidden">
|
||||
<!-- Background Pattern -->
|
||||
<div class="absolute inset-0 opacity-10">
|
||||
<div class="absolute inset-0 bg-gradient-to-r from-[#38BDF8]/20 to-[#8B5CF6]/20"></div>
|
||||
<div class="absolute top-0 left-0 w-full h-full" style="background-image: radial-gradient(circle at 1px 1px, rgba(56,189,248,0.15) 1px, transparent 0); background-size: 20px 20px;"></div>
|
||||
</div>
|
||||
|
||||
<div class="relative max-w-5xl mx-auto text-center">
|
||||
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl p-8 mb-12">
|
||||
<div class="max-w-5xl mx-auto text-center">
|
||||
<h1 class="text-4xl md:text-5xl font-bold text-[#E5E7EB] mb-4">
|
||||
<span class="bg-clip-text text-transparent bg-gradient-to-r from-[#38BDF8] via-[#8B5CF6] to-[#38BDF8]">
|
||||
Model Gallery
|
||||
</span>
|
||||
</h1>
|
||||
<p class="text-lg md:text-xl text-gray-300 mb-6 font-light">
|
||||
<p class="text-lg md:text-xl text-[#94A3B8] mb-6 font-light">
|
||||
Discover and install AI models from our curated collection
|
||||
</p>
|
||||
<div class="flex flex-wrap justify-center items-center gap-6 text-sm md:text-base">
|
||||
<div class="flex items-center bg-white/10 rounded-full px-4 py-2">
|
||||
<div class="w-2 h-2 bg-indigo-400 rounded-full mr-2 animate-pulse"></div>
|
||||
<div class="flex items-center bg-[#101827] rounded-lg px-4 py-2">
|
||||
<div class="w-2 h-2 bg-indigo-400 rounded-full mr-2"></div>
|
||||
<span class="font-semibold text-indigo-300" x-text="availableModels"></span>
|
||||
<span class="text-gray-300 ml-1">models available</span>
|
||||
<span class="text-[#94A3B8] ml-1">models available</span>
|
||||
</div>
|
||||
<div class="flex items-center bg-white/10 rounded-full px-4 py-2">
|
||||
<div class="w-2 h-2 bg-purple-400 rounded-full mr-2 animate-pulse"></div>
|
||||
<a href="/manage" class="flex items-center bg-[#101827] hover:bg-[#1E293B] rounded-lg px-4 py-2 transition-colors border border-[#38BDF8]/30 hover:border-[#38BDF8]/50">
|
||||
<div class="w-2 h-2 bg-emerald-400 rounded-full mr-2"></div>
|
||||
<span class="font-semibold text-emerald-300" x-text="installedModels"></span>
|
||||
<span class="text-[#94A3B8] ml-1">installed</span>
|
||||
</a>
|
||||
<div class="flex items-center bg-[#101827] rounded-lg px-4 py-2">
|
||||
<div class="w-2 h-2 bg-purple-400 rounded-full mr-2"></div>
|
||||
<span class="font-semibold text-purple-300" x-text="repositories.length"></span>
|
||||
<span class="text-gray-300 ml-1">repositories</span>
|
||||
<span class="text-[#94A3B8] ml-1">repositories</span>
|
||||
</div>
|
||||
<a href="https://localai.io/models/" target="_blank"
|
||||
class="flex items-center bg-blue-600/80 hover:bg-blue-600 text-white px-4 py-2 rounded-full transition-all duration-300 hover:scale-105">
|
||||
class="inline-flex items-center bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg transition-colors">
|
||||
<i class="fas fa-info-circle mr-2"></i>
|
||||
<span>Documentation</span>
|
||||
<i class="fas fa-external-link-alt ml-2 text-xs"></i>
|
||||
@@ -75,28 +74,26 @@
|
||||
{{template "views/partials/inprogress" .}}
|
||||
|
||||
<!-- Search and Filter Section -->
|
||||
<div class="relative bg-gradient-to-br from-gray-800/80 to-gray-900/80 rounded-2xl p-8 mb-8 shadow-xl border border-gray-700/50 backdrop-blur-sm">
|
||||
<div class="absolute inset-0 rounded-2xl bg-gradient-to-br from-blue-500/5 to-purple-500/5"></div>
|
||||
|
||||
<div class="relative">
|
||||
<div class="bg-[#1E293B] border border-[#38BDF8]/20 rounded-xl p-8 mb-8">
|
||||
<div>
|
||||
<!-- Search Input -->
|
||||
<div class="mb-8">
|
||||
<h3 class="text-xl font-semibold text-white mb-4 flex items-center">
|
||||
<i class="fas fa-search mr-3 text-blue-400"></i>
|
||||
<h3 class="text-xl font-semibold text-[#E5E7EB] mb-4 flex items-center">
|
||||
<i class="fas fa-search mr-3 text-[#38BDF8]"></i>
|
||||
Find Your Perfect Model
|
||||
</h3>
|
||||
<div class="relative">
|
||||
<div class="absolute inset-y-0 start-0 flex items-center ps-4 pointer-events-none">
|
||||
<i class="fas fa-search text-gray-400"></i>
|
||||
<i class="fas fa-search text-[#94A3B8]"></i>
|
||||
</div>
|
||||
<input
|
||||
x-model="searchTerm"
|
||||
@input.debounce.500ms="fetchModels()"
|
||||
class="w-full pl-12 pr-16 py-4 text-base font-normal text-gray-300 bg-gray-900/90 border border-gray-700/70 rounded-xl transition-all duration-300 focus:text-gray-200 focus:bg-gray-900 focus:border-blue-500 focus:ring-2 focus:ring-blue-500/50 focus:outline-none"
|
||||
class="w-full pl-12 pr-16 py-4 text-base font-normal text-[#E5E7EB] bg-[#101827] border border-[#1E293B] rounded-lg transition-colors focus:text-[#E5E7EB] focus:bg-[#101827] focus:border-[#38BDF8] focus:ring-2 focus:ring-[#38BDF8]/50 focus:outline-none"
|
||||
type="search"
|
||||
placeholder="Search models by name, tag, or description...">
|
||||
<span class="absolute right-4 top-4" x-show="loading">
|
||||
<svg class="animate-spin h-6 w-6 text-blue-500" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<svg class="animate-spin h-6 w-6 text-[#38BDF8]" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||
</svg>
|
||||
@@ -106,49 +103,49 @@
|
||||
|
||||
<!-- Filter by Type -->
|
||||
<div class="mb-8">
|
||||
<h3 class="text-lg font-semibold text-white mb-4 flex items-center">
|
||||
<i class="fas fa-filter mr-3 text-purple-400"></i>
|
||||
<h3 class="text-lg font-semibold text-[#E5E7EB] mb-4 flex items-center">
|
||||
<i class="fas fa-filter mr-3 text-[#8B5CF6]"></i>
|
||||
Filter by Model Type
|
||||
</h3>
|
||||
<div class="grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-4 xl:grid-cols-8 gap-3">
|
||||
<button @click="filterByTerm('tts')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-indigo-600/80 to-indigo-700/80 hover:from-indigo-600 hover:to-indigo-700 text-indigo-100 border border-indigo-500/30 hover:border-indigo-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-indigo-500/25">
|
||||
<i class="fas fa-microphone mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-indigo-600/20 hover:bg-indigo-600/30 text-indigo-300 border border-indigo-500/30 transition-colors">
|
||||
<i class="fas fa-microphone mr-2"></i>
|
||||
<span>TTS</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('stablediffusion')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-purple-600/80 to-purple-700/80 hover:from-purple-600 hover:to-purple-700 text-purple-100 border border-purple-500/30 hover:border-purple-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-purple-500/25">
|
||||
<i class="fas fa-image mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-purple-600/20 hover:bg-purple-600/30 text-purple-300 border border-purple-500/30 transition-colors">
|
||||
<i class="fas fa-image mr-2"></i>
|
||||
<span>Image</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('llm')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-blue-600/80 to-blue-700/80 hover:from-blue-600 hover:to-blue-700 text-blue-100 border border-blue-500/30 hover:border-blue-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-blue-500/25">
|
||||
<i class="fas fa-comment-alt mr-2 group-hover:animate-bounce"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-blue-600/20 hover:bg-blue-600/30 text-blue-300 border border-blue-500/30 transition-colors">
|
||||
<i class="fas fa-comment-alt mr-2"></i>
|
||||
<span>LLM</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('multimodal')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-green-600/80 to-green-700/80 hover:from-green-600 hover:to-green-700 text-green-100 border border-green-500/30 hover:border-green-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-green-500/25">
|
||||
<i class="fas fa-object-group mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-green-600/20 hover:bg-green-600/30 text-green-300 border border-green-500/30 transition-colors">
|
||||
<i class="fas fa-object-group mr-2"></i>
|
||||
<span>Multimodal</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('embedding')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-cyan-600/80 to-cyan-700/80 hover:from-cyan-600 hover:to-cyan-700 text-cyan-100 border border-cyan-500/30 hover:border-cyan-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-cyan-500/25">
|
||||
<i class="fas fa-vector-square mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-cyan-600/20 hover:bg-cyan-600/30 text-cyan-300 border border-cyan-500/30 transition-colors">
|
||||
<i class="fas fa-vector-square mr-2"></i>
|
||||
<span>Embedding</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('rerank')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-amber-600/80 to-amber-700/80 hover:from-amber-600 hover:to-amber-700 text-amber-100 border border-amber-500/30 hover:border-amber-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-amber-500/25">
|
||||
<i class="fas fa-sort-amount-up mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-amber-600/20 hover:bg-amber-600/30 text-amber-300 border border-amber-500/30 transition-colors">
|
||||
<i class="fas fa-sort-amount-up mr-2"></i>
|
||||
<span>Rerank</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('whisper')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-teal-600/80 to-teal-700/80 hover:from-teal-600 hover:to-teal-700 text-teal-100 border border-teal-500/30 hover:border-teal-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-teal-500/25">
|
||||
<i class="fas fa-headphones mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-teal-600/20 hover:bg-teal-600/30 text-teal-300 border border-teal-500/30 transition-colors">
|
||||
<i class="fas fa-headphones mr-2"></i>
|
||||
<span>Whisper</span>
|
||||
</button>
|
||||
<button @click="filterByTerm('object-detection')"
|
||||
class="group flex items-center justify-center rounded-xl px-4 py-3 text-sm font-semibold bg-gradient-to-r from-red-600/80 to-red-700/80 hover:from-red-600 hover:to-red-700 text-red-100 border border-red-500/30 hover:border-red-400/50 transition-all duration-300 transform hover:scale-105 hover:shadow-lg hover:shadow-red-500/25">
|
||||
<i class="fas fa-eye mr-2 group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center rounded-lg px-4 py-3 text-sm font-semibold bg-red-600/20 hover:bg-red-600/30 text-red-300 border border-red-500/30 transition-colors">
|
||||
<i class="fas fa-eye mr-2"></i>
|
||||
<span>Vision</span>
|
||||
</button>
|
||||
</div>
|
||||
@@ -156,16 +153,16 @@
|
||||
|
||||
<!-- Filter by Tags -->
|
||||
<div x-show="allTags.length > 0">
|
||||
<h3 class="text-lg font-semibold text-white mb-4 flex items-center">
|
||||
<i class="fas fa-tags mr-3 text-pink-400"></i>
|
||||
<h3 class="text-lg font-semibold text-[#E5E7EB] mb-4 flex items-center">
|
||||
<i class="fas fa-tags mr-3 text-[#8B5CF6]"></i>
|
||||
Browse by Tags
|
||||
</h3>
|
||||
<div class="max-h-32 overflow-y-auto scrollbar-thin scrollbar-thumb-gray-600/50 scrollbar-track-gray-800/50 pr-2">
|
||||
<div class="max-h-32 overflow-y-auto pr-2">
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<template x-for="tag in allTags" :key="tag">
|
||||
<button @click="filterByTerm(tag)"
|
||||
class="group inline-flex items-center text-xs px-3 py-2 rounded-full bg-gray-700/60 hover:bg-gray-600/80 text-gray-300 hover:text-white border border-gray-600/50 hover:border-gray-500/70 transition-all duration-200 ease-in-out transform hover:scale-105">
|
||||
<i class="fas fa-tag text-xs mr-2 group-hover:animate-pulse"></i>
|
||||
class="inline-flex items-center text-xs px-3 py-2 rounded bg-[#101827] hover:bg-[#101827]/80 text-[#94A3B8] hover:text-[#E5E7EB] border border-[#1E293B] transition-colors">
|
||||
<i class="fas fa-tag text-xs mr-2"></i>
|
||||
<span x-text="tag"></span>
|
||||
</button>
|
||||
</template>
|
||||
@@ -210,10 +207,14 @@
|
||||
<tr class="hover:bg-[#38BDF8]/10 transition-colors duration-200">
|
||||
<!-- Icon -->
|
||||
<td class="px-6 py-4">
|
||||
<img :src="model.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
|
||||
class="w-12 h-12 object-cover rounded-lg border border-[#38BDF8]/30"
|
||||
loading="lazy"
|
||||
:alt="model.name">
|
||||
<div class="w-12 h-12 rounded-lg border border-[#38BDF8]/30 flex items-center justify-center bg-[#101827]">
|
||||
<img x-show="model.icon"
|
||||
:src="model.icon"
|
||||
class="w-full h-full object-cover rounded-lg"
|
||||
loading="lazy"
|
||||
:alt="model.name">
|
||||
<i x-show="!model.icon" class="fas fa-brain text-xl text-[#38BDF8]"></i>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Model Name -->
|
||||
@@ -355,9 +356,13 @@
|
||||
<!-- Modal Body -->
|
||||
<div class="p-4 md:p-5 space-y-4 overflow-y-auto flex-1 min-h-0">
|
||||
<div class="flex justify-center items-center">
|
||||
<img :src="selectedModel?.icon || 'https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg'"
|
||||
class="lazy rounded-t-lg max-h-48 max-w-96 object-cover mt-3"
|
||||
loading="lazy">
|
||||
<div class="w-48 h-48 rounded-lg border border-gray-300 dark:border-gray-600 flex items-center justify-center bg-gray-100 dark:bg-gray-800 mt-3">
|
||||
<img x-show="selectedModel?.icon"
|
||||
:src="selectedModel?.icon"
|
||||
class="rounded-lg max-h-48 max-w-96 object-cover"
|
||||
loading="lazy">
|
||||
<i x-show="!selectedModel?.icon" class="fas fa-brain text-6xl text-gray-400 dark:text-gray-500"></i>
|
||||
</div>
|
||||
</div>
|
||||
<div class="text-base leading-relaxed text-gray-500 dark:text-gray-400 break-words max-w-full markdown-content" x-html="renderMarkdown(selectedModel?.description)"></div>
|
||||
<hr>
|
||||
@@ -424,8 +429,8 @@
|
||||
<button @click="goToPage(currentPage - 1)"
|
||||
:disabled="currentPage <= 1"
|
||||
:class="currentPage <= 1 ? 'opacity-50 cursor-not-allowed' : ''"
|
||||
class="group flex items-center justify-center h-12 w-12 bg-gray-700/80 hover:bg-indigo-600 text-gray-300 hover:text-white rounded-xl shadow-lg transition-all duration-300 ease-in-out transform hover:scale-110">
|
||||
<i class="fas fa-chevron-left group-hover:animate-pulse"></i>
|
||||
class="flex items-center justify-center h-12 w-12 bg-[#1E293B] hover:bg-indigo-600 text-[#94A3B8] hover:text-white rounded-lg transition-colors">
|
||||
<i class="fas fa-chevron-left"></i>
|
||||
</button>
|
||||
<div class="text-gray-300 text-sm font-medium px-4">
|
||||
<span class="text-gray-400">Page</span>
|
||||
@@ -559,6 +564,7 @@ function modelsGallery() {
|
||||
currentPage: 1,
|
||||
totalPages: 1,
|
||||
availableModels: 0,
|
||||
installedModels: 0,
|
||||
selectedModel: null,
|
||||
jobProgress: {},
|
||||
notifications: [],
|
||||
@@ -597,6 +603,7 @@ function modelsGallery() {
|
||||
this.currentPage = data.currentPage || 1;
|
||||
this.totalPages = data.totalPages || 1;
|
||||
this.availableModels = data.availableModels || 0;
|
||||
this.installedModels = data.installedModels || 0;
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
} finally {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user