diff --git a/backend/cpp/llama-cpp/Makefile b/backend/cpp/llama-cpp/Makefile
index 091c3386d..02f12013f 100644
--- a/backend/cpp/llama-cpp/Makefile
+++ b/backend/cpp/llama-cpp/Makefile
@@ -1,5 +1,5 @@
-LLAMA_VERSION?=7d019cff744b73084b15ca81ba9916f3efab1223
+LLAMA_VERSION?=92bb442ad999a0d52df0af2730cd861012e8ac5c
 
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
 
 CMAKE_ARGS?=
diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
index a71f43aec..1796df0dc 100644
--- a/backend/cpp/llama-cpp/grpc-server.cpp
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -579,7 +579,7 @@ public:
 
         auto completion_id = gen_chatcmplid();
-        std::unordered_set<int> task_ids;
+        server_response_reader rd(ctx_server);
 
         try {
             std::vector<server_task> tasks;
@@ -808,10 +808,9 @@ public:
                 }
             }
 
-            const auto & prompt = prompt_str;
             const auto type = SERVER_TASK_TYPE_COMPLETION;
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
-            //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+            //SRV_DBG("Prompt: %s\n", prompt_str.c_str());
 
             // If not using chat templates, extract files from image_data/audio_data fields
             // (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
@@ -871,18 +870,33 @@ public:
                 tasks.push_back(std::move(task));
             }
 
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            // Post tasks using server_response_reader
+            rd.post_tasks(std::move(tasks));
         } catch (const std::exception & e) {
             return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
         }
 
-        ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
+        // Stream results using server_response_reader
+        while (rd.has_next()) {
             // Check if context is cancelled before processing result
             if (context->IsCancelled()) {
-                ctx_server.cancel_tasks(task_ids);
-                return false;
+                rd.stop();
+                return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+            }
+
+            auto result = rd.next([&context]() { return context->IsCancelled(); });
+            if (result == nullptr) {
+                // Connection closed or cancelled
+                break;
+            }
+
+            if (result->is_error()) {
+                backend::Reply reply;
+                json error_json = result->to_json();
+                reply.set_message(error_json.value("message", "Error occurred"));
+                writer->Write(reply);
+                rd.stop();
+                break;
             }
 
             json res_json = result->to_json();
@@ -904,8 +918,6 @@ public:
                     reply.set_timing_token_generation(timing_token_generation);
                 }
 
-                // Log Request Correlation Id
-
                 // Send the reply
                 writer->Write(reply);
             }
@@ -926,24 +938,10 @@ public:
                 reply.set_timing_token_generation(timing_token_generation);
             }
-
-            // Send the reply
-            writer->Write(reply);
-
+            writer->Write(reply);
         }
 
-            return true;
-        }, [&](const json & error_data) {
-            backend::Reply reply;
-            reply.set_message(error_data.value("content", ""));
-            writer->Write(reply);
-            return true;
-        }, [&context]() {
-            // Check if the gRPC context is cancelled
-            return context->IsCancelled();
-        });
-
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+        }
 
         // Check if context was cancelled during processing
         if (context->IsCancelled()) {
@@ -964,6 +962,7 @@ public:
         std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl;
         auto completion_id = gen_chatcmplid();
         std::unordered_set<int> task_ids;
+        server_response_reader rd(ctx_server);
 
         try {
             std::vector<server_task> tasks;
@@ -1195,10 +1194,9 @@ public:
                 }
             }
 
-            const auto & prompt = prompt_str;
             const auto type = SERVER_TASK_TYPE_COMPLETION;
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
-            //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+            //SRV_DBG("Prompt: %s\n", prompt_str.c_str());
 
             // If not using chat templates, extract files from image_data/audio_data fields
             // (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
@@ -1261,61 +1259,57 @@ public:
                 tasks.push_back(std::move(task));
             }
 
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            // Post tasks using server_response_reader
+            rd.post_tasks(std::move(tasks));
         } catch (const std::exception & e) {
             return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
         }
 
-        std::cout << "[DEBUG] Waiting for results..." << std::endl;
-
+        // Check cancellation before waiting for results
         if (context->IsCancelled()) {
-            ctx_server.cancel_tasks(task_ids);
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            rd.stop();
             return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
         }
 
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
-            if (results.size() == 1) {
-                // single result
-                reply->set_message(results[0]->to_json().value("content", ""));
+        auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });
+
+        if (all_results.is_terminated) {
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
+        if (all_results.error) {
+            std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl;
+            reply->set_message(all_results.error->to_json().value("message", ""));
+            return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
+        }
 
-                int32_t tokens_predicted = results[0]->to_json().value("tokens_predicted", 0);
-                reply->set_tokens(tokens_predicted);
-                int32_t tokens_evaluated = results[0]->to_json().value("tokens_evaluated", 0);
-                reply->set_prompt_tokens(tokens_evaluated);
+        std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
+        if (all_results.results.size() == 1) {
+            // single result
+            reply->set_message(all_results.results[0]->to_json().value("content", ""));
 
-                if (results[0]->to_json().contains("timings")) {
-                    double timing_prompt_processing = results[0]->to_json().at("timings").value("prompt_ms", 0.0);
-                    reply->set_timing_prompt_processing(timing_prompt_processing);
-                    double timing_token_generation = results[0]->to_json().at("timings").value("predicted_ms", 0.0);
-                    reply->set_timing_token_generation(timing_token_generation);
-                }
+            int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
+            reply->set_tokens(tokens_predicted);
+            int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
+            reply->set_prompt_tokens(tokens_evaluated);
 
-            } else {
-                // multiple results (multitask)
-                json arr = json::array();
-                for (auto & res : results) {
-                    arr.push_back(res->to_json().value("content", ""));
-                }
-                reply->set_message(arr);
+            if (all_results.results[0]->to_json().contains("timings")) {
+                double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
+                reply->set_timing_prompt_processing(timing_prompt_processing);
+                double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
+                reply->set_timing_token_generation(timing_token_generation);
             }
-
-        }, [&](const json & error_data) {
-            std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
-            reply->set_message(error_data.value("content", ""));
-        }, [&context]() {
-            // Check if the gRPC context is cancelled
-            // This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
-            return context->IsCancelled();
-        });
-
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+        } else {
+            // multiple results (multitask)
+            json arr = json::array();
+            for (auto & res : all_results.results) {
+                arr.push_back(res->to_json().value("content", ""));
+            }
+            reply->set_message(arr);
+        }
 
         std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
 
         // Check if context was cancelled during processing
@@ -1354,7 +1348,6 @@ public:
         // create and queue the task
         json responses = json::array();
         bool error = false;
-        std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -1369,32 +1362,32 @@ public:
                 tasks.push_back(std::move(task));
             }
 
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
-        }
+            // Use server_response_reader for batch results
+            server_response_reader rd(ctx_server);
+            rd.post_tasks(std::move(tasks));
 
-        // Check cancellation before waiting for results
-        if (context->IsCancelled()) {
-            ctx_server.cancel_tasks(task_ids);
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
-        }
-
-        // get the result
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
-                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
-                responses.push_back(res->to_json());
+            // Check cancellation before waiting for results
+            if (context->IsCancelled()) {
+                rd.stop();
+                return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
             }
-        }, [&](const json & error_data) {
-            error = true;
-        }, [&context]() {
-            // Check if the gRPC context is cancelled
-            return context->IsCancelled();
-        });
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 
+            // get the result
+            auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });
+
+            if (all_results.is_terminated) {
+                return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+            }
+
+            if (all_results.error) {
+                error = true;
+            } else {
+                for (auto & res : all_results.results) {
+                    GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+                    responses.push_back(res->to_json());
+                }
+            }
+        }
 
         // Check if context was cancelled during processing
         if (context->IsCancelled()) {
@@ -1455,7 +1448,6 @@ public:
         // Create and queue the task
         json responses = json::array();
        bool error = false;
-        std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
             std::vector<std::string> documents;
@@ -1473,32 +1465,32 @@ public:
                 tasks.push_back(std::move(task));
             }
 
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
-        }
+            // Use server_response_reader for batch results
+            server_response_reader rd(ctx_server);
+            rd.post_tasks(std::move(tasks));
 
-        // Check cancellation before waiting for results
-        if (context->IsCancelled()) {
-            ctx_server.cancel_tasks(task_ids);
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
-        }
-
-        // Get the results
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
-                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
-                responses.push_back(res->to_json());
+            // Check cancellation before waiting for results
+            if (context->IsCancelled()) {
+                rd.stop();
+                return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
             }
-        }, [&](const json & error_data) {
-            error = true;
-        }, [&context]() {
-            // Check if the gRPC context is cancelled
-            return context->IsCancelled();
-        });
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 
+            // Get the results
+            auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });
+
+            if (all_results.is_terminated) {
+                return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+            }
+
+            if (all_results.error) {
+                error = true;
+            } else {
+                for (auto & res : all_results.results) {
+                    GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                    responses.push_back(res->to_json());
+                }
+            }
+        }
 
         // Check if context was cancelled during processing
         if (context->IsCancelled()) {
@@ -1543,14 +1535,13 @@ public:
         return grpc::Status::OK;
     }
 
-    grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
+    grpc::Status TokenizeString(ServerContext* /* context */, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
         json body = parse_options(false, request, ctx_server);
 
         body["stream"] = false;
 
         json tokens_response = json::array();
         if (body.count("prompt") != 0) {
             const bool add_special = json_value(body, "add_special", false);
-            const bool with_pieces = json_value(body, "with_pieces", false);
 
             llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true);