From 1ed6b9e5ed57a6a73316d57b59fdc1ebef88923d Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@localai.io>
Date: Fri, 3 Apr 2026 21:38:41 +0000
Subject: [PATCH] fix(llama.cpp): correctly parse grpc header for bearer token
 auth

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---
 backend/cpp/llama-cpp/grpc-server.cpp | 97 ++++++++++++++-------------
 1 file changed, 50 insertions(+), 47 deletions(-)

diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
index d9d5a5ca4..6017cb84a 100644
--- a/backend/cpp/llama-cpp/grpc-server.cpp
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -40,45 +40,41 @@ using grpc::ServerBuilder;
 using grpc::ServerContext;
 using grpc::Status;
 
-// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
+// gRPC bearer token auth for distributed mode.
 // Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
 // requests without a matching "authorization: Bearer <token>" metadata header.
-class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
-public:
-  explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
-  bool IsBlocking() const override { return false; }
+// Cached auth token — empty means auth is disabled.
+static std::string g_grpc_auth_token;
 
-  grpc::Status Process(const InputMetadata& auth_metadata,
-                       grpc::AuthContext* /*context*/,
-                       OutputMetadata* /*consumed_auth_metadata*/,
-                       OutputMetadata* /*response_metadata*/) override {
-    auto it = auth_metadata.find("authorization");
-    if (it != auth_metadata.end()) {
-      std::string expected = "Bearer " + token_;
-      std::string got(it->second.data(), it->second.size());
-      // Constant-time comparison
-      if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
-        return grpc::Status::OK;
-      }
-    }
-    return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
+// Minimal constant-time comparison (avoids OpenSSL dependency)
+static int ct_memcmp(const void* a, const void* b, size_t n) {
+  const unsigned char* pa = static_cast<const unsigned char*>(a);
+  const unsigned char* pb = static_cast<const unsigned char*>(b);
+  unsigned char result = 0;
+  for (size_t i = 0; i < n; i++) {
+    result |= pa[i] ^ pb[i];
   }
+  return result;
+}
 
-private:
-  std::string token_;
-
-  // Minimal constant-time comparison (avoids OpenSSL dependency)
-  static int ct_memcmp(const void* a, const void* b, size_t n) {
-    const unsigned char* pa = static_cast<const unsigned char*>(a);
-    const unsigned char* pb = static_cast<const unsigned char*>(b);
-    unsigned char result = 0;
-    for (size_t i = 0; i < n; i++) {
-      result |= pa[i] ^ pb[i];
-    }
-    return result;
+// Returns OK when auth is disabled or the token matches.
+static grpc::Status checkAuth(grpc::ServerContext* context) {
+  if (g_grpc_auth_token.empty()) {
+    return grpc::Status::OK;
   }
-};
+  auto metadata = context->client_metadata();
+  auto it = metadata.find("authorization");
+  if (it != metadata.end()) {
+    std::string expected = "Bearer " + g_grpc_auth_token;
+    std::string got(it->second.data(), it->second.size());
+    if (expected.size() == got.size() &&
+        ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
+      return grpc::Status::OK;
+    }
+  }
+  return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
+}
 
 // END LocalAI
@@ -757,13 +753,17 @@ private:
 public:
   BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
 
-  grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
+  grpc::Status Health(ServerContext* context, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
+    auto auth = checkAuth(context);
+    if (!auth.ok()) return auth;
     // Implement Health RPC
     reply->set_message("OK");
     return Status::OK;
   }
 
-  grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
+  grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
+    auto auth = checkAuth(context);
+    if (!auth.ok()) return auth;
     // Implement LoadModel RPC
     common_params params;
     params_parse(ctx_server, request, params);
@@ -962,6 +962,8 @@ public:
   }
 
   grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
+    auto auth = checkAuth(context);
+    if (!auth.ok()) return auth;
     if (params_base.model.path.empty()) {
       return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
     }
@@ -1665,6 +1667,8 @@ public:
   }
 
   grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
+    auto auth = checkAuth(context);
+    if (!auth.ok()) return auth;
     if (params_base.model.path.empty()) {
       return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
     }
@@ -2383,6 +2387,8 @@ public:
   }
 
   grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
+    auto auth = checkAuth(context);
+    if (!auth.ok()) return auth;
     if (params_base.model.path.empty()) {
       return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
     }
@@ -2563,7 +2569,9 @@ public:
     return grpc::Status::OK;
   }
 
-  grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
+  grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
+    auto auth = checkAuth(context);
+    if (!auth.ok()) return auth;
     if (params_base.model.path.empty()) {
       return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
     }
@@ -2803,19 +2811,14 @@ int main(int argc, char** argv) {
   BackendServiceImpl service(ctx_server);
   ServerBuilder builder;
 
-  // Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
-  const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
-  std::shared_ptr<grpc::ServerCredentials> creds;
-  if (auth_token != nullptr && auth_token[0] != '\0') {
-    creds = grpc::InsecureServerCredentials();
-    creds->SetAuthMetadataProcessor(
-        std::make_shared<TokenAuthMetadataProcessor>(auth_token));
-    std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
-  } else {
-    creds = grpc::InsecureServerCredentials();
-  }
+  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
 
-  builder.AddListeningPort(server_address, creds);
+  // Initialize bearer token auth if LOCALAI_GRPC_AUTH_TOKEN is set
+  const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
+  if (auth_token != nullptr && auth_token[0] != '\0') {
+    g_grpc_auth_token = auth_token;
+    std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
+  }
   builder.RegisterService(&service);
   builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
   builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB