Files
LocalAI/backend/cpp/llama-cpp/grpc-server.cpp
Ettore Di Giacinto 3728552e94 feat: import models via URI (#7245)
* feat: initial hook to install elements directly

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* WIP: ui changes

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Move HF api client to pkg

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add simple importer for gguf files

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add opcache

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* wire importers to CLI

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add omitempty to config fields

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add MLX importer

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Small refactors to star to use HF for discovery

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Common preferences

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add support to bare HF repos

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(importer/llama.cpp): add support for mmproj files

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add mmproj quants to common preferences

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix vlm usage in tokenizer mode with llama.cpp

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-11-12 20:48:56 +01:00

1669 lines
77 KiB
C++

// llama.cpp gRPC C++ backend server
//
// Ettore Di Giacinto <mudler@localai.io> and llama.cpp authors
//
// This is a gRPC server for llama.cpp compatible with the LocalAI proto
// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP (https://github.com/ggerganov/llama.cpp/tree/master/examples/server),
// but modified to work with gRPC
//
#include "server.cpp"
// LocalAI
#include "backend.pb.h"
#include "backend.grpc.pb.h"
#include "common.h"
#include <atomic>
#include <getopt.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <regex>
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
// END LocalAI
/////////////////////////////////
////////////////////////////////
//////// LOCALAI code starts below here
/////////////////////////////////
////////////////////////////////
// Set to true by LoadModel() once the model has been loaded; start_llama_server() polls it
// until then. The two sides run concurrently (the poller blocks until the gRPC handler sets it),
// so the flag is atomic: a plain bool read and written from different threads is a data race.
std::atomic<bool> loaded_model{false};
// Forward declarations
static void start_llama_server(server_context& ctx_server);
static json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server);
static ggml_type kv_cache_type_from_str(const std::string & s);
static std::string get_all_kv_cache_types();
static void add_rpc_devices(std::string servers);
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request, common_params & params);
// Worker entry point for the llama.cpp side of the gRPC backend.
// Waits until LoadModel() sets loaded_model, initializes ctx_server, wires the task-queue
// callbacks and the SIGINT/SIGTERM (Ctrl+C on Windows) handling, then runs the task loop.
// Does not return until ctx_server.queue_tasks.terminate() is called.
static void start_llama_server(server_context& ctx_server) {
    LOG_INF("%s: starting llama server\n", __func__);
    LOG_INF("%s: waiting for model to be loaded\n", __func__);

    // LoadModel() flips loaded_model once the model is in place; poll until then.
    const auto poll_interval = std::chrono::milliseconds(100);
    while (!loaded_model) {
        std::this_thread::sleep_for(poll_interval);
    }

    ctx_server.init();
    LOG_INF("%s: model loaded\n", __func__);

    // The chat templates prepared in load_model() are intentionally kept: Predict/PredictStream
    // apply them only when UseTokenizerTemplate is true and Messages are provided.

    auto & task_queue = ctx_server.queue_tasks;
    task_queue.on_new_task([&ctx_server](server_task && task) {
        ctx_server.process_single_task(std::move(task));
    });
    task_queue.on_update_slots([&ctx_server]() {
        ctx_server.update_slots();
    });

    // Shutdown hook: terminating the queue unblocks start_loop() below.
    shutdown_handler = [&ctx_server](int) {
        ctx_server.queue_tasks.terminate();
    };

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sa;
    sa.sa_handler = signal_handler;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = 0;
    sigaction(SIGINT, &sa, nullptr);
    sigaction(SIGTERM, &sa, nullptr);
#elif defined (_WIN32)
    auto on_console_ctrl = +[](DWORD ctrl_type) -> BOOL {
        if (ctrl_type != CTRL_C_EVENT) {
            return false;
        }
        signal_handler(SIGINT);
        return true;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(on_console_ctrl), true);
#endif

    // Blocks the calling thread until queue_tasks.terminate() is called.
    task_queue.start_loop();
}
// Builds the llama.cpp completion request JSON from a LocalAI PredictOptions message.
//   streaming  - written verbatim to "stream".
//   predict    - the incoming gRPC request.
//   ctx_server - source of the load-time grammar triggers / preserved tokens, which are
//                copied into every request.
// Keys are inserted in a fixed order. Tools / tool_choice are carried along so that
// Predict/PredictStream can decide later whether to let the chat template use them.
json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server)
{
    const auto & req = *predict;
    json data;

    // Generation and sampling settings. A token budget of 0 is sent as -1 (no limit).
    data["stream"] = streaming;
    data["cache_prompt"] = req.promptcacheall();
    data["n_predict"] = (req.tokens() != 0) ? req.tokens() : -1;
    data["top_k"] = req.topk();
    data["top_p"] = req.topp();
    data["typical_p"] = req.typicalp();
    data["temperature"] = req.temperature();
    data["repeat_last_n"] = req.repeat();
    data["repeat_penalty"] = req.penalty();
    data["frequency_penalty"] = req.frequencypenalty();
    data["presence_penalty"] = req.presencepenalty();
    data["mirostat"] = req.mirostat();
    data["mirostat_tau"] = req.mirostattau();
    data["mirostat_eta"] = req.mirostateta();
    data["n_keep"] = req.nkeep();
    data["seed"] = req.seed();

    // Caller-supplied grammar, if any.
    const std::string & grammar = req.grammar();
    if (!grammar.empty()) {
        data["grammar"] = grammar;
        SRV_INF("Using grammar: %s\n", grammar.c_str());
    }

    // With UseTokenizerTemplate and Messages, the prompt is rendered from the chat template
    // in Predict/PredictStream; only otherwise is the raw prompt forwarded here.
    const bool prompt_from_template = req.usetokenizertemplate() && req.messages_size() > 0;
    if (!prompt_from_template) {
        data["prompt"] = req.prompt();
    }

    // Tools arrive as a JSON string; malformed JSON is logged and dropped.
    if (!req.tools().empty()) {
        try {
            json parsed_tools = json::parse(req.tools());
            data["tools"] = parsed_tools;
            SRV_INF("Extracted tools from proto: %s\n", req.tools().c_str());
        } catch (const json::parse_error & err) {
            SRV_WRN("Failed to parse tools JSON from proto: %s\n", err.what());
        }
    }

    // tool_choice may be a JSON string ("auto", "none", "required"), a JSON object naming a
    // function (kept verbatim so it can be mapped to "required" later), or non-JSON text,
    // which is forwarded as a plain string.
    const std::string & tool_choice_raw = req.toolchoice();
    if (!tool_choice_raw.empty()) {
        try {
            json parsed_choice = json::parse(tool_choice_raw);
            data["tool_choice"] = parsed_choice;
            SRV_INF("Extracted tool_choice from proto: %s\n", tool_choice_raw.c_str());
        } catch (const json::parse_error &) {
            data["tool_choice"] = tool_choice_raw;
            SRV_INF("Extracted tool_choice as string: %s\n", tool_choice_raw.c_str());
        }
    }

    data["ignore_eos"] = req.ignoreeos();
    data["embeddings"] = req.embeddings();
    data["correlation_id"] = req.correlationid();

    // Media attachments: one {"id": <index>, "data": <base64 payload>} entry per item.
    for (int idx = 0; idx < req.images_size(); ++idx) {
        data["image_data"].push_back(json{{"id", idx}, {"data", req.images(idx)}});
    }
    for (int idx = 0; idx < req.audios_size(); ++idx) {
        data["audio_data"].push_back(json{{"id", idx}, {"data", req.audios(idx)}});
    }

    data["stop"] = req.stopprompts();

    // Load-time sampling configuration shared by every request.
    const auto & sampling = ctx_server.params_base.sampling;
    if (!sampling.grammar_triggers.empty()) {
        json triggers = json::array();
        for (const auto & trig : sampling.grammar_triggers) {
            json entry;
            entry["value"] = trig.value;
            // Sent as WORD regardless of stored type: upstream converts WORD to TOKEN internally.
            entry["type"] = static_cast<int>(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
            triggers.push_back(entry);
        }
        data["grammar_triggers"] = triggers;
    }
    if (!sampling.preserved_tokens.empty()) {
        json pieces = json::array();
        for (const auto & tok : sampling.preserved_tokens) {
            pieces.push_back(common_token_to_piece(ctx_server.ctx, tok));
        }
        data["preserved_tokens"] = pieces;
    }

    return data;
}
// ggml tensor types accepted for the KV cache (CacheTypeKey / CacheTypeValue model options).
const std::vector<ggml_type> kv_cache_types = {
    GGML_TYPE_F32,
    GGML_TYPE_F16,
    GGML_TYPE_BF16,
    GGML_TYPE_Q8_0,
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
    GGML_TYPE_IQ4_NL,
    GGML_TYPE_Q5_0,
    GGML_TYPE_Q5_1,
};

// Looks up the entry of kv_cache_types whose ggml_type_name() equals `s`.
// Throws std::runtime_error if `s` names no supported KV cache type.
static ggml_type kv_cache_type_from_str(const std::string & s) {
    for (size_t idx = 0; idx < kv_cache_types.size(); ++idx) {
        const ggml_type candidate = kv_cache_types[idx];
        if (s == ggml_type_name(candidate)) {
            return candidate;
        }
    }
    throw std::runtime_error("Unsupported cache type: " + s);
}

// Returns the ggml_type_name() of every entry in kv_cache_types, joined with ", ".
static std::string get_all_kv_cache_types() {
    std::string joined;
    for (size_t idx = 0; idx < kv_cache_types.size(); ++idx) {
        if (idx != 0) {
            joined += ", ";
        }
        joined += ggml_type_name(kv_cache_types[idx]);
    }
    return joined;
}
// Adds an RPC server
// https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6
static void add_rpc_devices(std::string servers) {
auto rpc_servers = string_split<std::string>(servers, ',');
if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified");
}
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
throw std::invalid_argument("failed to find RPC backend");
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
throw std::invalid_argument("failed to find RPC device add function");
}
for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
ggml_backend_device_register(dev);
} else {
throw std::invalid_argument("failed to register RPC device");
}
}
}
// Returns true when an option value spells an enabled boolean ("true", "1", "yes", "on", "enabled").
static bool localai_option_enabled(const std::string & value) {
    return value == "true" || value == "1" || value == "yes" || value == "on" || value == "enabled";
}

// Returns true when an option value spells a disabled boolean ("false", "0", "no", "off", "disabled").
static bool localai_option_disabled(const std::string & value) {
    return value == "false" || value == "0" || value == "no" || value == "off" || value == "disabled";
}

// Returns the directory part of `path` (text before the last '/' or '\\').
// When the path has no separator, returns "." so sibling files (mmproj, LoRA adapter) resolve
// next to the model rather than underneath "<model file>/".
static std::string localai_model_dir(const std::string & path) {
    const auto sep = path.find_last_of("/\\");
    if (sep == std::string::npos) {
        return ".";
    }
    return path.substr(0, sep);
}

// Translates LocalAI ModelOptions into llama.cpp common_params.
// Comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
//
// `ctx_server` is currently unused. Besides filling `params`, this may register RPC devices
// (grpc_servers/rpc_servers option, or LLAMACPP_GRPC_SERVERS) and reads LLAMACPP_PARALLEL.
// Throws on unsupported KV cache types, RPC registration failures, and malformed
// tensor_split / main_gpu values (std::stof / std::stoi).
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request,
                         common_params & params) {
    params.model.path = request->modelfile();
    if (!request->mmproj().empty()) {
        // The mmproj file is resolved relative to the model file's directory.
        params.mmproj.path = localai_model_dir(params.model.path) + "/" + request->mmproj();
    }
    // params.model_alias ??
    params.model_alias = request->modelfile();
    if (!request->cachetypekey().empty()) {
        params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
    }
    if (!request->cachetypevalue().empty()) {
        params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
    }
    params.n_ctx = request->contextsize();
    params.cpuparams.n_threads = request->threads();
    params.n_gpu_layers = request->ngpulayers();
    params.n_batch = request->nbatch();
    // n_ubatch == n_batch fixes reranking models being limited to 512 tokens (the default
    // n_ubatch size) and avoids "input is too large to process. increase the physical batch size".
    params.n_ubatch = request->nbatch();

    // Defaults that the options below may override.
    params.ctx_shift = false;
    params.cache_ram_mib = -1;  // -1 = no limit
    params.n_parallel = 1;
    std::string grpc_servers_option;

    // Options have the form "optname:optvalue", or just "optname" for booleans (value "true").
    // Only the FIRST ':' separates name and value, so values may themselves contain ':'
    // (e.g. "grpc_servers:host1:50052,host2:50052"). Empty or unknown options are ignored.
    for (int i = 0; i < request->options_size(); i++) {
        const std::string & opt = request->options(i);
        const auto sep = opt.find(':');
        const std::string optname = opt.substr(0, sep);
        std::string optval = (sep == std::string::npos) ? std::string() : opt.substr(sep + 1);
        if (optval.empty()) {
            optval = "true";
        }

        if (optname == "context_shift") {
            if (localai_option_enabled(optval)) {
                params.ctx_shift = true;
            } else if (localai_option_disabled(optval)) {
                params.ctx_shift = false;
            }
        } else if (optname == "use_jinja" || optname == "jinja") {
            if (localai_option_enabled(optval)) {
                params.use_jinja = true;
            } else if (localai_option_disabled(optval)) {
                params.use_jinja = false;
            }
        } else if (optname == "cache_ram") {
            try {
                params.cache_ram_mib = std::stoi(optval);
            } catch (const std::exception &) {
                // Not a number: keep the default (-1).
            }
        } else if (optname == "parallel" || optname == "n_parallel") {
            try {
                params.n_parallel = std::stoi(optval);
                if (params.n_parallel > 1) {
                    params.cont_batching = true;
                }
            } catch (const std::exception &) {
                // Not a number: keep the default (1).
            }
        } else if (optname == "grpc_servers" || optname == "rpc_servers") {
            grpc_servers_option = optval;
        }
    }

    // Fallback: LLAMACPP_PARALLEL applies whenever n_parallel is still 1
    // (this includes an explicit "parallel:1" option).
    if (params.n_parallel == 1) {
        const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
        if (env_parallel != NULL) {
            try {
                params.n_parallel = std::stoi(env_parallel);
                if (params.n_parallel > 1) {
                    params.cont_batching = true;
                }
            } catch (const std::exception &) {
                // Not a number: keep the default (1).
            }
        }
    }

    // RPC devices: the option wins; LLAMACPP_GRPC_SERVERS is the fallback.
    if (!grpc_servers_option.empty()) {
        add_rpc_devices(grpc_servers_option);
    } else {
        const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
        if (llama_grpc_servers != NULL) {
            add_rpc_devices(std::string(llama_grpc_servers));
        }
    }

    // kv_overrides, terminated by an entry with an empty key.
    for (int i = 0; i < request->overrides_size(); i++) {
        string_parse_kv_override(request->overrides(i).c_str(), params.kv_overrides);
    }
    if (!params.kv_overrides.empty()) {
        params.kv_overrides.emplace_back();
        params.kv_overrides.back().key[0] = 0;
    }

    // TODO: Add yarn
    if (!request->tensorsplit().empty()) {
        std::string arg_next = request->tensorsplit();
        // split string by , and /
        const std::regex regex{ R"([,/]+)" };
        std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
        std::vector<std::string> split_arg{ it, {} };
        GGML_ASSERT(split_arg.size() <= llama_max_devices());
        for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) {
            if (i_device < split_arg.size()) {
                params.tensor_split[i_device] = std::stof(split_arg[i_device]);
            } else {
                params.tensor_split[i_device] = 0.0f;
            }
        }
    }

    if (!request->maingpu().empty()) {
        params.main_gpu = std::stoi(request->maingpu());
    }

    if (!request->loraadapter().empty() && !request->lorabase().empty()) {
        // A lorascale of 0 means "unset": use 1.0.
        float scale_factor = 1.0f;
        if (request->lorascale() != 0.0f) {
            scale_factor = request->lorascale();
        }
        // The adapter is resolved relative to the model file's directory.
        params.lora_adapters.push_back({ localai_model_dir(params.model.path) + "/" + request->loraadapter(), scale_factor });
    }

    params.use_mlock = request->mlock();
    params.use_mmap = request->mmap();

    const std::string & flash_attn = request->flashattention();
    if (flash_attn == "on" || flash_attn == "enabled") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
    } else if (flash_attn == "off" || flash_attn == "disabled") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    } else if (flash_attn == "auto") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
    }

    params.no_kv_offload = request->nokvoffload();
    params.embedding = request->embeddings() || request->reranking();
    if (request->reranking()) {
        params.pooling_type = LLAMA_POOLING_TYPE_RANK;
    }

    if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
    else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
    else if (request->ropescaling() == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }

    // A value of 0 means "unset": keep llama.cpp's default.
    if (request->yarnextfactor() != 0.0f) {
        params.yarn_ext_factor = request->yarnextfactor();
    }
    if (request->yarnattnfactor() != 0.0f) {
        params.yarn_attn_factor = request->yarnattnfactor();
    }
    if (request->yarnbetafast() != 0.0f) {
        params.yarn_beta_fast = request->yarnbetafast();
    }
    if (request->yarnbetaslow() != 0.0f) {
        params.yarn_beta_slow = request->yarnbetaslow();
    }
    if (request->ropefreqbase() != 0.0f) {
        params.rope_freq_base = request->ropefreqbase();
    }
    if (request->ropefreqscale() != 0.0f) {
        params.rope_freq_scale = request->ropefreqscale();
    }

    // Grammar trigger words are stored as WORD triggers here; LoadModel() converts single-token
    // words to TOKEN triggers once the vocab is available.
    for (int i = 0; i < request->grammartriggers_size(); i++) {
        common_grammar_trigger trigger;
        trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
        trigger.value = request->grammartriggers(i).word();
        params.sampling.grammar_triggers.push_back(std::move(trigger));
    }
}
// GRPC Server start
class BackendServiceImpl final : public backend::Backend::Service {
private:
server_context& ctx_server;
public:
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
// Health RPC: always replies "OK". It does not check whether a model has been loaded.
grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) override {
    reply->set_message("OK");
    return Status::OK;
}
// LoadModel RPC: parses `request` into common_params, initializes llama.cpp and loads the model
// into ctx_server. On success it also converts single-token grammar trigger words into TOKEN
// triggers (the vocab is only available after loading), copies slot_prompt_similarity, and
// finally sets loaded_model, which releases start_llama_server() to call ctx_server.init().
// Returns INVALID_ARGUMENT if the options cannot be parsed, CANCELLED if the model fails to load,
// and OK otherwise. `result` always carries a message and the success flag.
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
    common_params params;
    // params_parse() throws on bad input (unknown KV cache type, RPC device errors, malformed
    // tensor_split / main_gpu); report it to the caller instead of letting it escape the handler.
    try {
        params_parse(ctx_server, request, params);
    } catch (const std::exception & e) {
        result->set_message(std::string("Failed parsing model options: ") + e.what());
        result->set_success(false);
        return Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
    }

    common_init();
    // Ensure debug logs are enabled after common_init() sets up logging
    common_log_set_verbosity_thold(params.verbosity);
    llama_backend_init();
    llama_numa_init(params.numa);
    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
    LOG_INF("\n");
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

    if (!ctx_server.load_model(params)) {
        result->set_message("Failed loading model");
        result->set_success(false);
        return Status::CANCELLED;
    }

    // Resolve grammar triggers now that the vocab is available: a WORD that tokenizes to exactly
    // one token becomes a TOKEN trigger and that token is added to preserved_tokens.
    if (!params.sampling.grammar_triggers.empty()) {
        std::vector<common_grammar_trigger> processed_triggers;
        processed_triggers.reserve(params.sampling.grammar_triggers.size());
        for (const auto & trigger : params.sampling.grammar_triggers) {
            if (trigger.type != COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                processed_triggers.push_back(trigger);
                continue;
            }
            auto ids = common_tokenize(ctx_server.vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
            if (ids.size() != 1) {
                LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
                processed_triggers.push_back(trigger);
                continue;
            }
            auto token = ids[0];
            if (params.sampling.preserved_tokens.insert(token).second) {
                LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
            }
            LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
            common_grammar_trigger processed_trigger;
            processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
            processed_trigger.value = trigger.value;
            processed_trigger.token = token;
            processed_triggers.push_back(std::move(processed_trigger));
        }
        ctx_server.params_base.sampling.grammar_triggers = std::move(processed_triggers);
        ctx_server.params_base.sampling.preserved_tokens = params.sampling.preserved_tokens;
    }

    // Finish mutating ctx_server BEFORE publishing loaded_model: start_llama_server() begins
    // using ctx_server (init() and the task loop) as soon as it observes the flag.
    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
    result->set_message("Loading succeeded");
    result->set_success(true);
    loaded_model = true;
    return Status::OK;
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
json data = parse_options(true, request, ctx_server);
//Raise error if embeddings is set to true
if (ctx_server.params_base.embedding) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode");
}
auto completion_id = gen_chatcmplid();
std::unordered_set<int> task_ids;
try {
std::vector<server_task> tasks;
std::string prompt_str;
std::vector<raw_buffer> files; // Declare files early so it's accessible in both branches
// Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided
if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.chat_templates != nullptr) {
// Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse
json body_json;
json messages_json = json::array();
// Find the last user message index to attach images/audio to
int last_user_msg_idx = -1;
for (int i = request->messages_size() - 1; i >= 0; i--) {
if (request->messages(i).role() == "user") {
last_user_msg_idx = i;
break;
}
}
for (int i = 0; i < request->messages_size(); i++) {
const auto& msg = request->messages(i);
json msg_json;
msg_json["role"] = msg.role();
bool is_last_user_msg = (i == last_user_msg_idx);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
// Handle content - can be string, null, or array
// For multimodal content, we'll embed images/audio from separate fields
if (!msg.content().empty()) {
// Try to parse content as JSON to see if it's already an array
json content_val;
try {
content_val = json::parse(msg.content());
} catch (const json::parse_error&) {
// Not JSON, treat as plain string
content_val = msg.content();
}
// If content is a string and this is the last user message with images/audio, combine them
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
json content_array = json::array();
// Add text first
content_array.push_back({{"type", "text"}, {"text", content_val.get<std::string>()}});
// Add images
if (request->images_size() > 0) {
for (int j = 0; j < request->images_size(); j++) {
json image_chunk;
image_chunk["type"] = "image_url";
json image_url;
image_url["url"] = "data:image/jpeg;base64," + request->images(j);
image_chunk["image_url"] = image_url;
content_array.push_back(image_chunk);
}
}
// Add audios
if (request->audios_size() > 0) {
for (int j = 0; j < request->audios_size(); j++) {
json audio_chunk;
audio_chunk["type"] = "input_audio";
json input_audio;
input_audio["data"] = request->audios(j);
input_audio["format"] = "wav"; // default, could be made configurable
audio_chunk["input_audio"] = input_audio;
content_array.push_back(audio_chunk);
}
}
msg_json["content"] = content_array;
} else {
// Use content as-is (already array or not last user message)
msg_json["content"] = content_val;
}
} else if (is_last_user_msg && has_images_or_audio) {
// If no content but this is the last user message with images/audio, create content array
json content_array = json::array();
if (request->images_size() > 0) {
for (int j = 0; j < request->images_size(); j++) {
json image_chunk;
image_chunk["type"] = "image_url";
json image_url;
image_url["url"] = "data:image/jpeg;base64," + request->images(j);
image_chunk["image_url"] = image_url;
content_array.push_back(image_chunk);
}
}
if (request->audios_size() > 0) {
for (int j = 0; j < request->audios_size(); j++) {
json audio_chunk;
audio_chunk["type"] = "input_audio";
json input_audio;
input_audio["data"] = request->audios(j);
input_audio["format"] = "wav"; // default, could be made configurable
audio_chunk["input_audio"] = input_audio;
content_array.push_back(audio_chunk);
}
}
msg_json["content"] = content_array;
}
// Add optional fields for OpenAI-compatible message format
if (!msg.name().empty()) {
msg_json["name"] = msg.name();
}
if (!msg.tool_call_id().empty()) {
msg_json["tool_call_id"] = msg.tool_call_id();
}
if (!msg.reasoning_content().empty()) {
msg_json["reasoning_content"] = msg.reasoning_content();
}
if (!msg.tool_calls().empty()) {
// Parse tool_calls JSON string and add to message
try {
json tool_calls = json::parse(msg.tool_calls());
msg_json["tool_calls"] = tool_calls;
} catch (const json::parse_error& e) {
SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
}
}
messages_json.push_back(msg_json);
}
body_json["messages"] = messages_json;
body_json["stream"] = true; // PredictStream is always streaming
// Check if grammar is provided from Go layer (NoGrammar=false)
// If grammar is provided, we must use it and NOT let template generate grammar from tools
// oaicompat_chat_params_parse throws an error if both grammar and tools are provided
bool has_grammar_from_go = data.contains("grammar") &&
data["grammar"].is_string() &&
!data["grammar"].get<std::string>().empty();
// Copy other relevant fields from data that oaicompat_chat_params_parse expects
// Tools and tool_choice are only passed when NoGrammar is true (grammar not provided)
// When grammar is provided from Go layer, we use it instead of template-generated grammar
if (!has_grammar_from_go) {
// NoGrammar=true: pass tools and let template generate grammar
if (data.contains("tools")) {
body_json["tools"] = data["tools"];
std::string tools_str = data["tools"].dump();
SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str());
} else {
SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n");
}
if (data.contains("tool_choice")) {
// tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string
// Convert object tool_choice to "required" (since a specific function is requested)
if (data["tool_choice"].is_string()) {
body_json["tool_choice"] = data["tool_choice"].get<std::string>();
} else if (data["tool_choice"].is_object()) {
// Object tool_choice means a specific function is requested, use "required"
body_json["tool_choice"] = "required";
std::string tool_choice_obj_str = data["tool_choice"].dump();
SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str());
} else {
// Fallback: convert to string
body_json["tool_choice"] = data["tool_choice"].dump();
}
std::string tool_choice_str = body_json["tool_choice"].get<std::string>();
SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str());
} else {
// Default to "auto" if not specified
body_json["tool_choice"] = "auto";
}
} else {
// Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools
SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n");
// Grammar will be copied from data after parsing (it's already in data)
}
if (data.contains("json_schema")) {
body_json["json_schema"] = data["json_schema"];
}
// If grammar is provided from Go layer, copy it to body_json so it's preserved
// (though oaicompat_chat_params_parse may not use it if tools are present)
if (has_grammar_from_go) {
body_json["grammar"] = data["grammar"];
}
if (data.contains("response_format")) {
body_json["response_format"] = data["response_format"];
}
if (data.contains("chat_template_kwargs")) {
body_json["chat_template_kwargs"] = data["chat_template_kwargs"];
}
// Use the same approach as server.cpp: call oaicompat_chat_params_parse
// This handles all template application, grammar merging, etc. automatically
// Files extracted from multimodal content in messages will be added to the files vector
// Create parser options with current chat_templates to ensure tmpls is not null
oaicompat_parser_options parser_opt = ctx_server.oai_parser_opt;
parser_opt.tmpls = ctx_server.chat_templates.get(); // Ensure tmpls is set to current chat_templates
// Update allow_image and allow_audio based on current mctx state
parser_opt.allow_image = ctx_server.mctx ? mtmd_support_vision(ctx_server.mctx) : false;
parser_opt.allow_audio = ctx_server.mctx ? mtmd_support_audio(ctx_server.mctx) : false;
json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);
// Extract the prompt from parsed data
prompt_str = parsed_data.at("prompt").get<std::string>();
// Preserve grammar from Go layer if it was provided (NoGrammar=false)
// Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true)
json preserved_grammar;
if (has_grammar_from_go && data.contains("grammar")) {
preserved_grammar = data["grammar"];
}
// Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, etc.)
// This ensures all template-generated fields are included
for (const auto& item : parsed_data.items()) {
if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it
// If grammar was provided from Go layer, preserve it instead of template-generated grammar
if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) {
data["grammar"] = preserved_grammar;
} else {
data[item.key()] = item.value();
}
}
}
} else {
// Use prompt directly from data
if (data.contains("prompt") && data["prompt"].is_string()) {
prompt_str = data["prompt"].get<std::string>();
} else {
prompt_str = request->prompt();
}
}
const auto & prompt = prompt_str;
const auto type = SERVER_TASK_TYPE_COMPLETION;
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
// If not using chat templates, extract files from image_data/audio_data fields
// (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.chat_templates == nullptr) {
const auto &images_data = data.find("image_data");
if (images_data != data.end() && images_data->is_array())
{
for (const auto &img : *images_data)
{
auto decoded_data = base64_decode(img["data"].get<std::string>());
files.push_back(decoded_data);
}
}
const auto &audio_data = data.find("audio_data");
if (audio_data != data.end() && audio_data->is_array())
{
for (const auto &audio : *audio_data)
{
auto decoded_data = base64_decode(audio["data"].get<std::string>());
files.push_back(decoded_data);
}
}
}
const bool has_mtmd = ctx_server.mctx != nullptr;
// process prompt
std::vector<server_tokens> inputs;
if (has_mtmd) {
// multimodal
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt_str, files));
} else {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt_str, true, true);
}
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
data);
task.id_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
task.params.oaicompat_cmpl_id = completion_id;
// oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(std::move(task));
}
task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(std::move(tasks));
} catch (const std::exception & e) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
}
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
return false;
}
json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
std::string completion_text = res.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res.contains("timings")) {
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Log Request Correlation Id
// Send the reply
writer->Write(reply);
}
} else {
std::string completion_text = res_json.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res_json.contains("timings")) {
double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Send the reply
writer->Write(reply);
}
return true;
}, [&](const json & error_data) {
backend::Reply reply;
reply.set_message(error_data.value("content", ""));
writer->Write(reply);
return true;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
return grpc::Status::OK;
}
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
    // Unary (non-streaming) completion: build the prompt (optionally via the model's chat template),
    // queue completion task(s), wait for all results and copy content/token counts/timings into `reply`.
    json data = parse_options(true, request, ctx_server);
    data["stream"] = false;
    //Raise error if embeddings is set to true
    if (ctx_server.params_base.embedding) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in Predict mode");
    }
    // Log only a compact summary: the full request holds the whole prompt and base64 media payloads.
    SRV_DBG("Predict request: %d message(s), %d image(s), %d audio clip(s)\n",
            request->messages_size(), request->images_size(), request->audios_size());
    auto completion_id = gen_chatcmplid();
    std::unordered_set<int> task_ids;
    try {
        std::vector<server_task> tasks;
        std::string prompt_str;
        // Media for the multimodal tokenizer: filled by oaicompat_chat_params_parse on the chat-template
        // path, or from the image_data/audio_data fields otherwise.
        std::vector<raw_buffer> files;
        // Chat templates are applied only when requested, messages are present and templates are loaded.
        const bool use_chat_template = request->usetokenizertemplate()
                                    && request->messages_size() > 0
                                    && ctx_server.chat_templates != nullptr;
        if (use_chat_template) {
            // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse
            json body_json;
            json messages_json = json::array();
            // Images/audio from the request are attached to the last user message only.
            int last_user_msg_idx = -1;
            for (int i = request->messages_size() - 1; i >= 0; i--) {
                if (request->messages(i).role() == "user") {
                    last_user_msg_idx = i;
                    break;
                }
            }
            const bool has_images_or_audio = request->images_size() > 0 || request->audios_size() > 0;
            // Appends every request image (as a JPEG data: URL) and audio clip (as WAV) to an
            // OpenAI-style content array.
            const auto append_media_parts = [request](json & content_array) {
                for (int j = 0; j < request->images_size(); j++) {
                    json image_url;
                    image_url["url"] = "data:image/jpeg;base64," + request->images(j);
                    json image_chunk;
                    image_chunk["type"] = "image_url";
                    image_chunk["image_url"] = image_url;
                    content_array.push_back(image_chunk);
                }
                for (int j = 0; j < request->audios_size(); j++) {
                    json input_audio;
                    input_audio["data"] = request->audios(j);
                    input_audio["format"] = "wav"; // default, could be made configurable
                    json audio_chunk;
                    audio_chunk["type"] = "input_audio";
                    audio_chunk["input_audio"] = input_audio;
                    content_array.push_back(audio_chunk);
                }
            };
            // Message content is treated as pre-built content parts only when it parses as a non-empty
            // JSON array of typed objects. Any other text - including text that merely happens to be valid
            // JSON, such as "42" or a tool result like {"temp": 20} - is kept verbatim as a string, because
            // oaicompat_chat_params_parse only accepts string/null/array content.
            const auto parse_content = [](const std::string & raw) -> json {
                json parsed;
                try {
                    parsed = json::parse(raw);
                } catch (const json::parse_error &) {
                    return raw;
                }
                if (!parsed.is_array() || parsed.empty()) {
                    return raw;
                }
                for (const auto & part : parsed) {
                    if (!part.is_object() || !part.contains("type")) {
                        return raw;
                    }
                }
                return parsed;
            };
            for (int i = 0; i < request->messages_size(); i++) {
                const auto& msg = request->messages(i);
                json msg_json;
                msg_json["role"] = msg.role();
                const bool attach_media = (i == last_user_msg_idx) && has_images_or_audio;
                if (!msg.content().empty()) {
                    json content_val = parse_content(msg.content());
                    if (content_val.is_string() && attach_media) {
                        // Text plus media: build a content array with the text part first.
                        json content_array = json::array();
                        content_array.push_back({{"type", "text"}, {"text", content_val.get<std::string>()}});
                        append_media_parts(content_array);
                        msg_json["content"] = content_array;
                    } else {
                        // Use content as-is (already an array of parts, or no media to attach)
                        msg_json["content"] = content_val;
                    }
                } else if (attach_media) {
                    // No text but media to attach: content consists of media parts only.
                    json content_array = json::array();
                    append_media_parts(content_array);
                    msg_json["content"] = content_array;
                } else if (!msg.tool_calls().empty()) {
                    // Tool call messages may have null content
                    msg_json["content"] = json();
                }
                // Add optional fields for OpenAI-compatible message format
                if (!msg.name().empty()) {
                    msg_json["name"] = msg.name();
                }
                if (!msg.tool_call_id().empty()) {
                    msg_json["tool_call_id"] = msg.tool_call_id();
                }
                if (!msg.reasoning_content().empty()) {
                    msg_json["reasoning_content"] = msg.reasoning_content();
                }
                if (!msg.tool_calls().empty()) {
                    // Parse tool_calls JSON string and add to message
                    try {
                        msg_json["tool_calls"] = json::parse(msg.tool_calls());
                    } catch (const json::parse_error& e) {
                        SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
                    }
                }
                messages_json.push_back(msg_json);
            }
            body_json["messages"] = messages_json;
            body_json["stream"] = false;
            // A non-empty grammar from the Go layer (NoGrammar=false) must be used as-is, so tools are not
            // forwarded: oaicompat_chat_params_parse throws if both grammar and tools are provided.
            const bool has_grammar_from_go = data.contains("grammar")
                                          && data.at("grammar").is_string()
                                          && !data.at("grammar").get<std::string>().empty();
            if (!has_grammar_from_go) {
                // NoGrammar=true: pass tools and let template generate grammar
                if (data.contains("tools")) {
                    body_json["tools"] = data["tools"];
                    std::string tools_str = data["tools"].dump();
                    SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str());
                } else {
                    SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n");
                }
                if (data.contains("tool_choice")) {
                    // oaicompat_chat_params_parse expects a string tool_choice; an object names a
                    // specific function, which maps to "required".
                    if (data["tool_choice"].is_string()) {
                        body_json["tool_choice"] = data["tool_choice"].get<std::string>();
                    } else if (data["tool_choice"].is_object()) {
                        body_json["tool_choice"] = "required";
                        std::string tool_choice_obj_str = data["tool_choice"].dump();
                        SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str());
                    } else {
                        // Fallback: convert to string
                        body_json["tool_choice"] = data["tool_choice"].dump();
                    }
                    std::string tool_choice_str = body_json["tool_choice"].get<std::string>();
                    SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str());
                } else {
                    // Default to "auto" if not specified
                    body_json["tool_choice"] = "auto";
                }
            } else {
                SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n");
            }
            if (data.contains("json_schema")) {
                body_json["json_schema"] = data["json_schema"];
            }
            if (has_grammar_from_go) {
                body_json["grammar"] = data["grammar"];
            }
            if (data.contains("response_format")) {
                body_json["response_format"] = data["response_format"];
            }
            if (data.contains("chat_template_kwargs")) {
                body_json["chat_template_kwargs"] = data["chat_template_kwargs"];
            }
            // Same approach as server.cpp: oaicompat_chat_params_parse applies the template, merges grammar
            // and extracts multimodal files from the message content into `files`.
            oaicompat_parser_options parser_opt = ctx_server.oai_parser_opt;
            parser_opt.tmpls = ctx_server.chat_templates.get(); // Ensure tmpls is set to current chat_templates
            // Update allow_image and allow_audio based on current mctx state
            parser_opt.allow_image = ctx_server.mctx ? mtmd_support_vision(ctx_server.mctx) : false;
            parser_opt.allow_audio = ctx_server.mctx ? mtmd_support_audio(ctx_server.mctx) : false;
            json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);
            prompt_str = parsed_data.at("prompt").get<std::string>();
            // Merge template-generated fields (grammar, grammar_triggers, preserved_tokens, ...) into data.
            // The prompt was extracted above, and a grammar supplied by the Go layer is already in `data`
            // and takes precedence over the template-generated one.
            for (const auto& item : parsed_data.items()) {
                if (item.key() == "prompt" || (item.key() == "grammar" && has_grammar_from_go)) {
                    continue;
                }
                data[item.key()] = item.value();
            }
        } else {
            // Use prompt directly from data
            if (data.contains("prompt") && data["prompt"].is_string()) {
                prompt_str = data["prompt"].get<std::string>();
            } else {
                prompt_str = request->prompt();
            }
        }
        const auto type = SERVER_TASK_TYPE_COMPLETION;
        // Without chat templates, media comes from the image_data/audio_data fields.
        // .at() turns a malformed entry into an exception (reported as INVALID_ARGUMENT below).
        if (!use_chat_template) {
            const auto images_data = data.find("image_data");
            if (images_data != data.end() && images_data->is_array()) {
                SRV_DBG("Predict: decoding %zu image(s)\n", images_data->size());
                for (const auto & img : *images_data) {
                    files.push_back(base64_decode(img.at("data").get<std::string>()));
                }
            }
            const auto audio_data = data.find("audio_data");
            if (audio_data != data.end() && audio_data->is_array()) {
                for (const auto & audio : *audio_data) {
                    files.push_back(base64_decode(audio.at("data").get<std::string>()));
                }
            }
        }
        // process prompt
        const bool has_mtmd = ctx_server.mctx != nullptr;
        std::vector<server_tokens> inputs;
        if (has_mtmd) {
            // multimodal
            inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt_str, files));
        } else {
            // Everything else, including multimodal completions.
            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt_str, true, true);
        }
        tasks.reserve(inputs.size());
        for (size_t i = 0; i < inputs.size(); i++) {
            server_task task = server_task(type);
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;
            task.tokens = std::move(inputs[i]);
            task.params = server_task::params_from_json_cmpl(
                ctx_server.ctx,
                ctx_server.params_base,
                data);
            task.id_slot = json_value(data, "id_slot", -1);
            // OAI-compat
            task.params.oaicompat = OAICOMPAT_TYPE_NONE;
            task.params.oaicompat_cmpl_id = completion_id;
            // oaicompat_model is already populated by params_from_json_cmpl
            tasks.push_back(std::move(task));
        }
        task_ids = server_task::get_list_id(tasks);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(std::move(tasks));
    } catch (const std::exception & e) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
    }
    // Check cancellation before waiting for results
    if (context->IsCancelled()) {
        ctx_server.cancel_tasks(task_ids);
        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    bool error = false;
    std::string error_msg;
    ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
        SRV_DBG("Predict: received %zu result(s)\n", results.size());
        if (results.size() == 1) {
            // single result: serialize once and copy content, token counts and timings
            const json res = results[0]->to_json();
            reply->set_message(res.value("content", ""));
            const int32_t tokens_predicted = res.value("tokens_predicted", 0);
            reply->set_tokens(tokens_predicted);
            const int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
            reply->set_prompt_tokens(tokens_evaluated);
            if (res.contains("timings")) {
                const json & timings = res.at("timings");
                reply->set_timing_prompt_processing(timings.value("prompt_ms", 0.0));
                reply->set_timing_token_generation(timings.value("predicted_ms", 0.0));
            }
        } else {
            // multiple results (multitask): the message field is a string, so send the contents
            // serialized as a JSON array (an implicit json->string conversion of an array would throw)
            json arr = json::array();
            for (auto & res : results) {
                arr.push_back(res->to_json().value("content", ""));
            }
            reply->set_message(arr.dump());
        }
    }, [&](const json & error_data) {
        // llama.cpp error results carry their text under "message"; "content" is kept as a fallback.
        error = true;
        error_msg = error_data.value("message", error_data.value("content", std::string()));
        SRV_WRN("Predict failed: %s\n", error_msg.c_str());
    }, [&context]() {
        // Check if the gRPC context is cancelled
        // This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
        return context->IsCancelled();
    });
    ctx_server.queue_results.remove_waiting_task_ids(task_ids);
    // Check if context was cancelled during processing
    if (context->IsCancelled()) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    // Surface generation failures as an error status instead of an empty completion.
    if (error) {
        return grpc::Status(grpc::StatusCode::INTERNAL, error_msg.empty() ? "Error in receiving results" : error_msg);
    }
    return grpc::Status::OK;
}
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
    // Compute embeddings for the "embeddings" input and append all returned values (flattened) to
    // embeddingResult.
    json body = parse_options(false, request, ctx_server);
    body["stream"] = false;
    /*
    if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Pooling type 'none' is not OAI compatible. Please use a different pooling type");
    }
    */
    // A missing "embeddings" field or an input of unsupported shape throws; report it as a bad request
    // carrying the underlying message instead of letting the exception escape the handler.
    std::vector<server_tokens> tokenized_prompts;
    try {
        // for the shape of input/content, see tokenize_input_prompts()
        const json & prompt = body.at("embeddings");
        tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
    } catch (const std::exception & e) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
    }
    for (const auto & tokens : tokenized_prompts) {
        // this check is necessary for models that do not add BOS token to the input
        if (tokens.empty()) {
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Input content cannot be empty");
        }
    }
    const int embd_normalize = 2; // default to Euclidean/L2 norm
    // create and queue one embedding task per input
    json responses = json::array();
    bool error = false;
    std::string error_msg;
    std::unordered_set<int> task_ids;
    {
        std::vector<server_task> tasks;
        tasks.reserve(tokenized_prompts.size());
        for (size_t i = 0; i < tokenized_prompts.size(); i++) {
            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;
            task.tokens = std::move(tokenized_prompts[i]);
            task.params.oaicompat = OAICOMPAT_TYPE_NONE;
            task.params.embd_normalize = embd_normalize;
            tasks.push_back(std::move(task));
        }
        task_ids = server_task::get_list_id(tasks);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(std::move(tasks));
    }
    // Check cancellation before waiting for results
    if (context->IsCancelled()) {
        ctx_server.cancel_tasks(task_ids);
        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    // get the result
    ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
        for (auto & res : results) {
            GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
            responses.push_back(res->to_json());
        }
    }, [&](const json & error_data) {
        // keep the server's error text (llama.cpp stores it under "message") for the returned status
        error = true;
        error_msg = error_data.value("message", std::string());
    }, [&context]() {
        // Check if the gRPC context is cancelled
        return context->IsCancelled();
    });
    ctx_server.queue_results.remove_waiting_task_ids(task_ids);
    // Check if context was cancelled during processing
    if (context->IsCancelled()) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    if (error) {
        return grpc::Status(grpc::StatusCode::INTERNAL,
                            error_msg.empty() ? std::string("Error in receiving results")
                                              : "Error in receiving results: " + error_msg);
    }
    SRV_DBG("Embedding: %zu response(s)\n", responses.size());
    // Flatten the embeddings of every response into embeddingResult
    for (const auto & response_elem : responses) {
        if (response_elem.contains("embedding")) {
            // "embedding" is expected to be an array of vectors
            json embedding_data = json_value(response_elem, "embedding", json::array());
            if (embedding_data.is_array() && !embedding_data.empty()) {
                for (const auto & embedding_vector : embedding_data) {
                    if (embedding_vector.is_array()) {
                        for (const auto & embedding_value : embedding_vector) {
                            embeddingResult->add_embeddings(embedding_value.get<float>());
                        }
                    }
                }
            }
        } else if (response_elem.is_array()) {
            // the response itself contains the embedding data directly
            for (const auto & embedding_value : response_elem) {
                embeddingResult->add_embeddings(embedding_value.get<float>());
            }
        }
    }
    return grpc::Status::OK;
}
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
    // Score every document against the query, return them sorted by descending score (optionally cropped
    // to top_n) and report token usage.
    if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
        return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
    }
    // Validate request
    if (request->query().empty()) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
    }
    if (request->documents_size() == 0) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
    }
    // Create and queue one rerank task per document (task.index == document index)
    json responses = json::array();
    bool error = false;
    std::string error_msg;
    std::unordered_set<int> task_ids;
    {
        const int n_docs = request->documents_size();
        std::vector<server_task> tasks;
        tasks.reserve(n_docs);
        for (int i = 0; i < n_docs; i++) {
            auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, request->query(), request->documents(i));
            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;
            task.tokens = std::move(tmp);
            tasks.push_back(std::move(task));
        }
        task_ids = server_task::get_list_id(tasks);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(std::move(tasks));
    }
    // Check cancellation before waiting for results
    if (context->IsCancelled()) {
        ctx_server.cancel_tasks(task_ids);
        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    // Get the results
    ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
        for (auto & res : results) {
            GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
            responses.push_back(res->to_json());
        }
    }, [&](const json & error_data) {
        // keep the server's error text (llama.cpp stores it under "message") for the returned status
        error = true;
        error_msg = error_data.value("message", std::string());
    }, [&context]() {
        // Check if the gRPC context is cancelled
        return context->IsCancelled();
    });
    ctx_server.queue_results.remove_waiting_task_ids(task_ids);
    // Check if context was cancelled during processing
    if (context->IsCancelled()) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    if (error) {
        return grpc::Status(grpc::StatusCode::INTERNAL,
                            error_msg.empty() ? std::string("Error in receiving results")
                                              : "Error in receiving results: " + error_msg);
    }
    // Every document was evaluated, so usage is totalled before top_n cropping (as llama.cpp's server does).
    int total_tokens = 0;
    for (const auto & response : responses) {
        total_tokens += response.value("tokens_evaluated", 0);
    }
    // Sort responses by score in descending order
    std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
        return a.value("score", 0.0f) > b.value("score", 0.0f);
    });
    // Crop results by request.top_n if specified
    const int top_n = request->top_n();
    if (top_n > 0 && top_n < static_cast<int>(responses.size())) {
        responses = json(responses.begin(), responses.begin() + top_n);
    }
    // Create document results
    for (const auto& response : responses) {
        const int index = response.value("index", 0);
        backend::DocumentResult* doc_result = rerankResult->add_results();
        doc_result->set_index(index);
        doc_result->set_text(request->documents(index));
        doc_result->set_relevance_score(response.value("score", 0.0f));
    }
    // Set usage information (reranking only consumes prompt tokens)
    backend::Usage* usage = rerankResult->mutable_usage();
    usage->set_total_tokens(total_tokens);
    usage->set_prompt_tokens(total_tokens);
    return grpc::Status::OK;
}
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
    // Tokenize the request text with the model vocabulary and return the token ids.
    json body = parse_options(false, request, ctx_server);
    body["stream"] = false;
    // The text is read from "content" when present (llama.cpp /tokenize convention), otherwise from
    // "prompt". Testing for one key and then reading the other throws json out_of_range.
    const char * text_key = body.contains("content") ? "content"
                          : body.contains("prompt")  ? "prompt"
                          : nullptr;
    if (text_key == nullptr) {
        // nothing to tokenize
        return grpc::Status::OK;
    }
    const bool add_special = json_value(body, "add_special", false);
    try {
        const llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at(text_key), add_special, true);
        for (const auto & token : tokens) {
            response->add_tokens(token);
        }
    } catch (const std::exception & e) {
        // e.g. a prompt value of an unsupported shape
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
    }
    return grpc::Status::OK;
}
grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
    // Fetch a metrics snapshot from the server loop through the task queue and map it onto the response.
    const int metrics_task_id = ctx_server.queue_tasks.get_new_id();
    {
        server_task metrics_task(SERVER_TASK_TYPE_METRICS);
        metrics_task.id = metrics_task_id;
        ctx_server.queue_results.add_waiting_task_id(metrics_task_id);
        // posted as a high-priority task
        ctx_server.queue_tasks.post(std::move(metrics_task), true);
    }
    // Block until the snapshot arrives, then stop waiting on this id.
    server_task_result_ptr snapshot = ctx_server.queue_results.recv(metrics_task_id);
    ctx_server.queue_results.remove_waiting_task_id(metrics_task_id);
    // Slot id and per-slot prompt JSON are always reported as 0 / empty.
    response->set_slot_id(0);
    response->set_prompt_json_for_slot("");
    if (snapshot->is_error()) {
        // Handle case when no active slot exists: zero the counters and report the failure.
        response->set_tokens_per_second(0);
        response->set_tokens_generated(0);
        response->set_prompt_tokens_processed(0);
        return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
    }
    // TODO: get rid of this dynamic_cast
    const auto * metrics = dynamic_cast<server_task_result_metrics*>(snapshot.get());
    GGML_ASSERT(metrics != nullptr);
    // Prompt-processing throughput (tokens per second; t_prompt_processing is in milliseconds).
    const double prompt_tokens_per_second = metrics->n_prompt_tokens_processed
        ? 1.e3 / metrics->t_prompt_processing * metrics->n_prompt_tokens_processed
        : 0.;
    response->set_tokens_per_second(prompt_tokens_per_second);
    response->set_tokens_generated(metrics->n_tokens_predicted_total);
    response->set_prompt_tokens_processed(metrics->n_prompt_tokens_processed_total);
    return grpc::Status::OK;
}
};
int main(int argc, char** argv) {
    // Entry point: parse --addr/-a, start the gRPC service on a background thread and run the
    // llama.cpp server loop on the main thread until it returns.
    std::string server_address("localhost:50051");
    // Upper bound for gRPC message sizes (prompts and base64 media can be large).
    constexpr int kMaxMessageBytes = 50 * 1024 * 1024; // 50MB
    // Define long and short options
    struct option long_options[] = {
        {"addr", required_argument, nullptr, 'a'},
        {nullptr, 0, nullptr, 0}
    };
    // Parse command-line arguments
    int option;
    int option_index = 0;
    while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) {
        switch (option) {
            case 'a':
                server_address = optarg;
                break;
            default:
                std::cerr << "Usage: " << argv[0] << " [--addr=<address>] or [-a <address>]" << std::endl;
                return 1;
        }
    }
    server_context ctx_server;
    BackendServiceImpl service(ctx_server);
    ServerBuilder builder;
    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
    builder.RegisterService(&service);
    builder.SetMaxMessageSize(kMaxMessageBytes);
    builder.SetMaxSendMessageSize(kMaxMessageBytes);
    builder.SetMaxReceiveMessageSize(kMaxMessageBytes);
    std::unique_ptr<Server> server(builder.BuildAndStart());
    // BuildAndStart() returns nullptr when the server cannot start (e.g. the address is invalid or
    // already in use); calling Wait() on it would dereference null.
    if (!server) {
        std::cerr << "Failed to start gRPC server on " << server_address << std::endl;
        return 1;
    }
    // Serve gRPC requests on a background thread; the main thread runs start_llama_server() below.
    std::thread t([&]() {
        std::cout << "Server listening on " << server_address << std::endl;
        server->Wait();
    });
    // clean up function, to be called before exit
    auto clean_up = [&server, &ctx_server]() {
        SRV_INF("%s: cleaning up before exit...\n", __func__);
        server->Shutdown();
        ctx_server.queue_results.terminate();
        llama_backend_free();
    };
    start_llama_server(ctx_server);
    std::cout << "stopping" << std::endl;
    clean_up();
    t.join();
    return 0;
}