diff --git a/backend/cpp/llama-cpp/Makefile b/backend/cpp/llama-cpp/Makefile
index 1bf630fd6..edb2d7a88 100644
--- a/backend/cpp/llama-cpp/Makefile
+++ b/backend/cpp/llama-cpp/Makefile
@@ -1,5 +1,5 @@
-LLAMA_VERSION?=ae9f8df77882716b1702df2bed8919499e64cc28
+LLAMA_VERSION?=480160d47297df43b43746294963476fc0a6e10f
 
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
 
 CMAKE_ARGS?=
diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
index 1009d36fd..2e652cf9f 100644
--- a/backend/cpp/llama-cpp/grpc-server.cpp
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -390,8 +390,9 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
     // Initialize fit_params options (can be overridden by options)
     // fit_params: whether to auto-adjust params to fit device memory (default: true as in llama.cpp)
     params.fit_params = true;
-    // fit_params_target: target margin per device in bytes (default: 1GB)
-    params.fit_params_target = 1024 * 1024 * 1024;
+    // fit_params_target: target margin per device in bytes (default: 1GB per device)
+    // Initialize as vector with default value for all devices
+    params.fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024 * 1024);
     // fit_params_min_ctx: minimum context size for fit (default: 4096)
     params.fit_params_min_ctx = 4096;
 
@@ -468,10 +469,28 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
         } else if (!strcmp(optname, "fit_params_target") || !strcmp(optname, "fit_target")) {
             if (optval != NULL) {
                 try {
-                    // Value is in MiB, convert to bytes
-                    params.fit_params_target = static_cast<size_t>(std::stoi(optval_str)) * 1024 * 1024;
+                    // Value is in MiB, can be comma-separated list for multiple devices
+                    // Single value is broadcast across all devices
+                    std::string arg_next = optval_str;
+                    const std::regex regex{ R"([,/]+)" };
+                    std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
+                    std::vector<std::string> split_arg{ it, {} };
+                    if (split_arg.size() >= llama_max_devices()) {
+                        // Too many values provided
+                        continue;
+                    }
+                    if (split_arg.size() == 1) {
+                        // Single value: broadcast to all devices
+                        size_t value_mib = std::stoul(split_arg[0]);
+                        std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), value_mib * 1024 * 1024);
+                    } else {
+                        // Multiple values: set per device
+                        for (size_t i = 0; i < split_arg.size() && i < params.fit_params_target.size(); i++) {
+                            params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024 * 1024;
+                        }
+                    }
                 } catch (const std::exception& e) {
-                    // If conversion fails, keep default value (1GB)
+                    // If conversion fails, keep default value (1GB per device)
                 }
             }
         } else if (!strcmp(optname, "fit_params_min_ctx") || !strcmp(optname, "fit_ctx")) {
@@ -686,13 +705,13 @@ private:
 public:
     BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
 
-    grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) {
+    grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
         // Implement Health RPC
         reply->set_message("OK");
         return Status::OK;
     }
 
-    grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) {
+    grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
         // Implement LoadModel RPC
         common_params params;
         params_parse(ctx_server, request, params);
@@ -1492,7 +1511,7 @@ public:
         return grpc::Status::OK;
     }
 
-    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
+    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
         if (params_base.model.path.empty()) {
             return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
         }
@@ -2163,7 +2182,7 @@ public:
         return grpc::Status::OK;
     }
 
-    grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
+    grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
         if (params_base.model.path.empty()) {
             return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
         }
@@ -2258,7 +2277,7 @@ public:
         return grpc::Status::OK;
     }
 
-    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) override {
         if (!params_base.embedding || params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
             return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
         }
@@ -2344,7 +2363,7 @@ public:
         return grpc::Status::OK;
     }
 
-    grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
+    grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
        if (params_base.model.path.empty()) {
            return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
        }
@@ -2367,7 +2386,7 @@ public:
         return grpc::Status::OK;
     }
 
-    grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) {
+    grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
         // request slots data using task queue
         auto rd = ctx_server.get_response_reader();
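
Note (illustrative only, not part of the patch): the standalone sketch below mirrors the comma-separated parsing and single-value broadcast that the second hunk adds for fit_params_target. It stubs llama_max_devices() with a fixed device count and assumes the target field is a std::vector<size_t> of byte values, as the patched default initialization suggests; the option string format ("MiB per device") is taken from the patch comments.

// Minimal sketch of the per-device fit_params_target parsing (assumptions noted above).
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Stand-in for llama_max_devices(); assumed device count for the example.
static size_t llama_max_devices_stub() { return 2; }

int main() {
    // e.g. option "fit_params_target:2048,1024" -> optval_str == "2048,1024" (MiB per device)
    std::string optval_str = "2048,1024";

    // Default: 1 GiB margin per device, matching the patched initialization.
    std::vector<size_t> fit_params_target(llama_max_devices_stub(), 1024u * 1024u * 1024u);

    // Split on ',' or '/' exactly as in the patch.
    const std::regex regex{ R"([,/]+)" };
    std::sregex_token_iterator it{ optval_str.begin(), optval_str.end(), regex, -1 };
    std::vector<std::string> split_arg{ it, {} };

    if (split_arg.size() == 1) {
        // Single value: broadcast to all devices, converting MiB to bytes.
        size_t value_mib = std::stoul(split_arg[0]);
        std::fill(fit_params_target.begin(), fit_params_target.end(), value_mib * 1024 * 1024);
    } else {
        // Multiple values: one target per device; entries beyond the device count are ignored.
        for (size_t i = 0; i < split_arg.size() && i < fit_params_target.size(); i++) {
            fit_params_target[i] = std::stoul(split_arg[i]) * 1024 * 1024;
        }
    }

    for (size_t i = 0; i < fit_params_target.size(); i++) {
        std::cout << "device " << i << ": " << fit_params_target[i] << " bytes\n";
    }
    // Prints: device 0: 2147483648 bytes, device 1: 1073741824 bytes
    return 0;
}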