mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-31 12:07:45 -04:00
* feat(ds4): add standalone ds4-worker distributed worker binary

Add worker_main.c, a minimal standalone worker that owns a slice of the model's transformer layers and serves activations over ds4's own TCP transport via ds4_dist_run(). It links the same engine objects the backend already builds (including ds4_distributed.o) and has NO gRPC/protobuf dependency, so it builds even on hosts lacking protobuf/grpc dev headers. Launched by `local-ai worker ds4-distributed`.

Wire the ds4-worker CMake target (mirrors grpc-server's object/GPU/native handling) and have the Makefile copy + clean the binary alongside grpc-server. Ignore the built ds4-worker artifact.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* feat(ds4): package ds4-worker alongside grpc-server

Copy the standalone ds4-worker binary into the backend package (Linux package.sh) and the Darwin OCI tar (ds4-darwin.sh: both the explicit copy and the otool dylib-bundling loop) so distributed workers ship with the backend.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* fix(ds4): tighten ds4-worker integer arg validation to match upstream

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* feat(ds4): wire grpc-server as distributed coordinator

Add distributed COORDINATOR support to the ds4 backend's gRPC server. Distributed inference is an engine backend: when LoadModel receives 'ds4_role:coordinator', the process populates ds4_engine_options.distributed (role, layer slice, listen host/port) before ds4_engine_open, then the normal ds4_session_* generation path runs transparently once the worker route covers all layers.

- New LoadModel options: ds4_role, ds4_layers (START:END or START:output), ds4_listen (host:port), ds4_route_timeout.
- parse_layers_spec() maps the layer spec onto ds4_distributed_layers.
- wait_route_ready() blocks generation until ds4_session_distributed_route_ready() reports full coverage (or timeout), gating both Predict and PredictStream; returns UNAVAILABLE on timeout/error.
- No ds4_role => g_distributed stays false and wait_route_ready is a no-op, so single-node behavior is unchanged.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* fix(ds4): don't block Status during route wait; validate coordinator opts

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* feat(cli): add ds4-distributed worker exec helper

Add the ds4WorkerArgs helper plus findDS4Backend/DS4Distributed.Run that resolve the ds4 backend via the gallery and exec the packaged ds4-worker binary. Unlike worker_llamacpp.go, ds4 bundles its own dynamic loader (lib/ld.so) for glibc compatibility, so when present we exec ds4-worker through that loader with LD_LIBRARY_PATH=<backend>/lib, mirroring backend/cpp/ds4/run.sh; otherwise we exec it directly.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* feat(cli): register the ds4-distributed worker subcommand

Wire DS4Distributed into the Worker kong command tree so `local-ai worker ds4-distributed` is available.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* docs(ds4): document layer-split distributed inference

Add a ds4 section to the distributed-mode feature docs (coordinator model YAML, manual worker command, layer-range semantics, the 'GGUF on every machine' requirement, coordinator-listens dial direction vs llama.cpp) and a terse Distributed mode section to the ds4 backend agent guide.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* test(ds4): opt-in hardware-gated distributed e2e spec

Add a self-contained, opt-in Ginkgo spec to the backend e2e suite that spins a ds4 coordinator (via the packaged run.sh, loaded with ds4_role/ds4_layers/ds4_listen options) plus a ds4-worker process for the upper layers, then uses Eventually to assert a short successful Predict once the layer route forms, before tearing the worker down.

Gated by BACKEND_TEST_DS4_DISTRIBUTED=1 (plus the existing BACKEND_BINARY + BACKEND_TEST_MODEL_FILE and optional layer/listen/accel knobs); compiles and skips cleanly with no env, hardware, or model.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* test(ds4): pass coordinator ctx to worker; lowercase error string

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* docs(ds4): note distributed transport is plaintext/unauthenticated

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* style(ds4): replace em dashes in distributed docs/agent/test per repo convention

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* fix(ds4): link ds4-worker with the C++ driver for CUDA/Metal builds

The ds4-worker target is built from worker_main.c (C), so CMake linked it with the C driver. The nvcc-built ds4_cuda.o (and Obj-C++ ds4_metal.o) reference the C++ runtime, so the CUDA/Metal builds failed with undefined libstdc++ symbols (std::__throw_length_error). The CPU build passed because ds4_cpu.o is pure C. Force LINKER_LANGUAGE CXX so libstdc++ is linked.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
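
The exec-helper commit describes a two-way launch decision: run ds4-worker through the bundled `lib/ld.so` with `LD_LIBRARY_PATH=<backend>/lib` when that loader exists, otherwise exec the binary directly. The real helper is Go code in the CLI. What follows is only an illustrative C++ sketch of that decision. It assumes `ds4-worker` sits at the backend root, and the names `WorkerLaunch` and `plan_worker_launch` are hypothetical.

```cpp
#include <string>
#include <vector>

// Hypothetical launch plan; assumes <backend>/ds4-worker and <backend>/lib/ld.so.
struct WorkerLaunch {
    std::string exe;                // program to hand to execve()
    std::vector<std::string> argv;  // argv[0] first
    std::vector<std::string> env;   // extra "KEY=VALUE" entries
};

WorkerLaunch plan_worker_launch(const std::string &backend_dir,
                                const std::vector<std::string> &worker_args,
                                bool bundled_loader_present) {
    const std::string worker = backend_dir + "/ds4-worker";
    WorkerLaunch plan;
    if (bundled_loader_present) {
        // Run the worker through the bundled glibc loader so it resolves the
        // bundled shared libraries from <backend>/lib (as run.sh does).
        const std::string lib_dir = backend_dir + "/lib";
        plan.exe = lib_dir + "/ld.so";
        plan.argv = {plan.exe, worker};
        plan.env = {"LD_LIBRARY_PATH=" + lib_dir};
    } else {
        // No bundled loader: exec the worker binary directly.
        plan.exe = worker;
        plan.argv = {worker};
    }
    // Worker arguments always follow the worker path.
    plan.argv.insert(plan.argv.end(), worker_args.begin(), worker_args.end());
    return plan;
}
```

When the loader is used, the worker path becomes an ordinary argument to `ld.so`. Because of that, the worker's own arguments must come after it, which is why the sketch appends them last in both branches.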
842 lines
33 KiB
C++
// ds4 LocalAI gRPC backend.
//
// Wraps antirez/ds4's `ds4_engine_*` / `ds4_session_*` public API
// (see ds4/ds4.h) over LocalAI's backend.proto. Besides load/free and
// tokenization, it implements blocking and streaming generation with DSML
// tool calls, thinking mode, the disk KV cache, and the optional
// distributed-coordinator role.

#include "backend.pb.h"
#include "backend.grpc.pb.h"

#include "dsml_parser.h"    // DSML tool-call / reasoning stream parser
#include "dsml_renderer.h"  // tools manifest, tool-call and tool-result rendering
#include "kv_cache.h"       // disk KV cache keyed on rendered prompt bytes

extern "C" {
#include "ds4.h"
}

#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <climits>
#include <csignal>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::ServerWriter;
// NOTE: do NOT alias `grpc::Status` as `Status` - the Status RPC method below
// would shadow the type, breaking the other RPC method declarations that use
// it as a return type. Use GStatus instead.
using GStatus = ::grpc::Status;
using grpc::StatusCode;

namespace {

// Global state - ds4 is single-engine-per-process by design.
std::mutex g_engine_mu;
ds4_engine *g_engine = nullptr;
ds4_session *g_session = nullptr;
int g_ctx_size = 32768;
std::string g_kv_cache_dir;  // empty disables disk cache

// Distributed coordinator state. g_distributed is set true when LoadModel is
// given 'ds4_role:coordinator'; generation then waits for the worker route to
// form before running. Single-node behavior is unchanged when unset.
bool g_distributed = false;
int g_route_timeout_sec = 60;

std::atomic<Server *> g_server{nullptr};

// Split a "key:value" option string at the FIRST colon. With no colon, the
// whole string is the key and the value is empty.
static std::pair<std::string, std::string> split_option(const std::string &opt) {
    auto colon = opt.find(':');
    if (colon == std::string::npos) return {opt, ""};
    return {opt.substr(0, colon), opt.substr(colon + 1)};
}
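
Because split_option() cuts only at the first colon, option values may themselves contain colons, for example `ds4_layers:0:20` or `ds4_listen:0.0.0.0:9000`. LoadModel then splits the listen value a second time. The standalone sketch below reproduces both steps so the behavior can be checked in isolation. `split_first_colon` is a copy of split_option(); `parse_listen` is a hypothetical helper that condenses the ds4_listen checks from LoadModel.

```cpp
#include <climits>
#include <cstdlib>
#include <string>
#include <utility>

// Copy of split_option(): split at the FIRST colon only.
static std::pair<std::string, std::string> split_first_colon(const std::string &s) {
    auto colon = s.find(':');
    if (colon == std::string::npos) return {s, ""};
    return {s.substr(0, colon), s.substr(colon + 1)};
}

// Hypothetical condensation of LoadModel's ds4_listen validation:
// split "host:port" once more, then require a positive integer port.
static bool parse_listen(const std::string &value, std::string *host, int *port) {
    auto hp = split_first_colon(value);
    if (hp.second.empty()) return false;  // no ":port" part
    char *end = nullptr;
    long p = std::strtol(hp.second.c_str(), &end, 10);
    if (*end != '\0' || p <= 0 || p > INT_MAX) return false;
    *host = hp.first;
    *port = static_cast<int>(p);
    return true;
}
```

With these helpers, `split_first_colon("ds4_listen:0.0.0.0:9000")` yields the key `ds4_listen` and the value `0.0.0.0:9000`. `parse_listen` then turns that value into host `0.0.0.0` and port `9000`. An IPv6 literal such as `::1:9000` fails: the first colon is at offset 0, so the host is empty and the remainder `:1:9000` is not an integer. This is why the comment in LoadModel calls IPv6 literals unsupported.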

// Parse a positive base-10 integer. Returns false (without throwing) on empty,
// trailing garbage, non-positive, or overflow - unlike std::stoi.
static bool parse_positive_int(const std::string &s, int *out) {
    if (s.empty()) return false;
    char *end = nullptr;
    long v = std::strtol(s.c_str(), &end, 10);
    if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
    *out = static_cast<int>(v);
    return true;
}
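
The comment above contrasts this parser with `std::stoi`. The standalone comparison below makes the difference concrete. `strict_positive_int` follows parse_positive_int(). As an addition not present in the original, it also clears and checks `errno`: on platforms where `long` is 32 bits, `strtol` clamps an overflowing input to `LONG_MAX == INT_MAX`, and without the check that value would be accepted. Both function names are illustrative.

```cpp
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <stdexcept>
#include <string>

// Strict parse in the style of parse_positive_int(), plus an ERANGE check.
static bool strict_positive_int(const std::string &s, int *out) {
    if (s.empty()) return false;
    errno = 0;
    char *end = nullptr;
    long v = std::strtol(s.c_str(), &end, 10);
    if (errno == ERANGE || *end != '\0' || v <= 0 || v > INT_MAX) return false;
    *out = static_cast<int>(v);
    return true;
}

// std::stoi accepts a numeric prefix and throws on everything else.
static bool stoi_accepts(const std::string &s, int *out) {
    try {
        *out = std::stoi(s);
        return true;
    } catch (const std::exception &) {  // std::invalid_argument or std::out_of_range
        return false;
    }
}
```

The two disagree on input like `"60s"`. `std::stoi` returns 60 and silently drops the suffix, while the strict parser rejects the string. For an empty string, `std::stoi` throws `std::invalid_argument`, which escapes the RPC handler unless something catches it.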

// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
// distributed layer fields. Returns false on malformed input.
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
    auto colon = spec.find(':');
    if (colon == std::string::npos) return false;
    std::string lhs = spec.substr(0, colon);
    std::string rhs = spec.substr(colon + 1);
    if (lhs.empty() || rhs.empty()) return false;
    char *end = nullptr;
    long start = std::strtol(lhs.c_str(), &end, 10);
    if (!end || *end != '\0' || start < 0) return false;
    out->start = static_cast<uint32_t>(start);
    out->has_output = false;
    if (rhs == "output") {
        out->has_output = true;
        out->end = out->start;  // engine treats has_output as "through final layer"
    } else {
        long e = std::strtol(rhs.c_str(), &end, 10);
        if (!end || *end != '\0' || e < start) return false;
        out->end = static_cast<uint32_t>(e);
    }
    out->set = true;
    return true;
}
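
The accepted and rejected layer specs can be exercised in isolation with a copy of the same rules. The sketch below uses a stand-in struct, `LayerSlice`, whose field names follow the assignments above; the real `ds4_distributed_layers` lives in ds4.h and may have more members. Whether END is inclusive or exclusive is up to the engine, so the sketch does not assume either. One deliberate difference: it writes `*out` only on success, whereas the original may partially fill `out` before rejecting the END part.

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Stand-in for ds4_distributed_layers (hypothetical; see ds4.h for the real one).
struct LayerSlice {
    uint32_t start = 0;
    uint32_t end = 0;
    bool has_output = false;
    bool set = false;
};

// Same rules as parse_layers_spec(): "START:END" with END >= START,
// or "START:output" for a slice that runs through the final layer.
static bool parse_slice(const std::string &spec, LayerSlice *out) {
    auto colon = spec.find(':');
    if (colon == std::string::npos) return false;
    std::string lhs = spec.substr(0, colon), rhs = spec.substr(colon + 1);
    if (lhs.empty() || rhs.empty()) return false;
    char *end = nullptr;
    long start = std::strtol(lhs.c_str(), &end, 10);
    if (*end != '\0' || start < 0) return false;
    LayerSlice s;
    s.start = static_cast<uint32_t>(start);
    if (rhs == "output") {
        s.has_output = true;
        s.end = s.start;  // engine reads has_output as "through the final layer"
    } else {
        long e = std::strtol(rhs.c_str(), &end, 10);
        if (*end != '\0' || e < start) return false;
        s.end = static_cast<uint32_t>(e);
    }
    s.set = true;
    *out = s;  // commit only after every check passed
    return true;
}
```

Under these rules, `"0:20"` and `"21:output"` parse. `"20:10"` (END before START), `":5"`, `"-1:4"`, `"12"` (no colon), and `"3:x"` are all rejected.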

// When acting as a distributed coordinator, block until the worker route
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
// elapses. Returns an empty string on success, or an error message to return
// to the client. No-op when not distributed.
//
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
// connect; holding g_engine_mu the whole time would block the Status/Health
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
// a still-starting worker as hung.
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
    if (!g_distributed) return "";
    char err[256] = {0};
    const int deadline_polls = g_route_timeout_sec * 10;  // 100ms per poll
    for (int i = 0; i <= deadline_polls; ++i) {
        int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
        if (ready == 1) return "";
        if (ready < 0) {
            return std::string("ds4 distributed route error: ") +
                   (err[0] ? err : "unknown");
        }
        // Release the lock while sleeping so Status/Health and other RPCs can
        // interleave during worker startup.
        lock.unlock();
        struct timespec ts = {0, 100L * 1000L * 1000L};  // 100ms
        nanosleep(&ts, nullptr);
        lock.lock();
        // A concurrent Free() may have torn down the engine while we slept.
        if (!g_engine || !g_session) {
            return "ds4: model unloaded while waiting for distributed route";
        }
    }
    return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
}
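
The lock-release pattern in wait_route_ready() can be shown independently of ds4. The sketch below generalizes it: `poll_until_ready` is a hypothetical helper, and the `poll` callback stands in for ds4_session_distributed_route_ready(), returning 1 for ready, 0 for not yet, and a negative value for an error. The original also re-checks after each re-lock that Free() has not unloaded the model. That check is omitted here because the sketch has no engine.

```cpp
#include <chrono>
#include <functional>
#include <mutex>
#include <string>
#include <thread>

// Poll with `lock` held, but release the mutex during each sleep so other
// threads (e.g. a health probe) can acquire it. Returns "" on success.
static std::string poll_until_ready(std::unique_lock<std::mutex> &lock,
                                    const std::function<int()> &poll,
                                    int max_polls,
                                    std::chrono::milliseconds interval) {
    for (int i = 0; i <= max_polls; ++i) {
        int r = poll();  // evaluated while holding the lock
        if (r == 1) return "";
        if (r < 0) return "route error";
        lock.unlock();  // let other lock holders run during the sleep
        std::this_thread::sleep_for(interval);
        lock.lock();
    }
    return "timed out";
}
```

A quick way to see why the unlock matters: make readiness depend on another thread having acquired the same mutex. If the waiter held the lock for the whole wait, that thread could never run and the wait would always time out.

The `std::condition_variable` alternative does not fit here, because the engine exposes readiness only through polling and sends no notification.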

static void append_token_text(ds4_engine *engine, int token, std::string &out) {
    size_t len = 0;
    const char *text = ds4_token_text(engine, token, &len);
    if (text && len > 0) out.append(text, len);
}

struct CollectCtx {
    ds4_engine *engine;
    std::string raw_buf;  // exact raw bytes for Reply.message
    ds4cpp::DsmlParser parser;
    backend::Reply *reply;
    int tokens;

    // Per-tool aggregation: accumulate ChatDelta tool_calls so we emit one
    // delta with all calls, mirroring how vllm's non-streaming path returns.
    struct Pending {
        std::string id;
        std::string name;
        std::string args;
    };
    std::vector<Pending> pending;

    std::string content_buf;
    std::string reasoning_buf;
};

static void apply_events(CollectCtx *c, const std::vector<ds4cpp::ParserEvent> &events) {
    for (const auto &e : events) {
        switch (e.type) {
        case ds4cpp::ParserEvent::CONTENT:
            c->content_buf += e.text;
            break;
        case ds4cpp::ParserEvent::REASONING:
            c->reasoning_buf += e.text;
            break;
        case ds4cpp::ParserEvent::TOOL_START:
            if ((int)c->pending.size() <= e.index)
                c->pending.resize(e.index + 1);
            c->pending[e.index].id = e.tool_id;
            c->pending[e.index].name = e.tool_name;
            break;
        case ds4cpp::ParserEvent::TOOL_ARGS:
            if ((int)c->pending.size() > e.index)
                c->pending[e.index].args += e.text;
            break;
        case ds4cpp::ParserEvent::TOOL_END:
            // No-op for non-streaming: the final delta is emitted at the end.
            break;
        }
    }
}
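
apply_events() aggregates tool calls by parser index. TOOL_START grows the `pending` vector and records the id and name, and each TOOL_ARGS fragment is appended to that slot. The sketch below isolates that aggregation. `Kind`, `Event`, and `Pending` are minimal stand-ins covering only the fields used here; they are not the real ds4cpp types.

```cpp
#include <string>
#include <vector>

// Minimal stand-ins for ds4cpp::ParserEvent and CollectCtx::Pending.
enum class Kind { ToolStart, ToolArgs };
struct Event {
    Kind type;
    int index;
    std::string tool_id, tool_name, text;
};
struct Pending {
    std::string id, name, args;
};

// Index-keyed aggregation, as in apply_events().
static void aggregate(std::vector<Pending> &pending, const std::vector<Event> &events) {
    for (const auto &e : events) {
        if (e.type == Kind::ToolStart) {
            if (static_cast<int>(pending.size()) <= e.index) pending.resize(e.index + 1);
            pending[e.index].id = e.tool_id;
            pending[e.index].name = e.tool_name;
        } else if (static_cast<int>(pending.size()) > e.index) {
            pending[e.index].args += e.text;  // fragments arrive in order
        }
        // ToolArgs for an index with no prior ToolStart is dropped, as above.
    }
}
```

The argument fragments `{"city":` and `"Rome"}` for index 0 therefore accumulate into one complete JSON string. Predict() can then emit a single ChatDelta holding every call, instead of one delta per fragment as the streaming path does.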

static void collect_emit(void *ud, int token) {
    auto *c = static_cast<CollectCtx *>(ud);
    if (token == ds4_token_eos(c->engine)) return;
    size_t len = 0;
    const char *text = ds4_token_text(c->engine, token, &len);
    if (!text || len == 0) return;
    std::string chunk(text, len);
    c->raw_buf += chunk;
    std::vector<ds4cpp::ParserEvent> events;
    c->parser.Feed(chunk, events);
    apply_events(c, events);
    c->tokens++;
}
static void collect_done(void *) {}

struct StreamCtx {
    ds4_engine *engine;
    ServerWriter<backend::Reply> *writer;
    ds4cpp::DsmlParser parser;
    int tokens;
    bool aborted;
    // Track which tool indices we've seen TOOL_START for, so subsequent
    // ARGS deltas can elide the redundant id/name fields.
    std::vector<bool> tool_started;
};

static void stream_emit(void *ud, int token) {
    auto *s = static_cast<StreamCtx *>(ud);
    if (s->aborted) return;
    if (token == ds4_token_eos(s->engine)) return;
    size_t len = 0;
    const char *text = ds4_token_text(s->engine, token, &len);
    if (!text || len == 0) return;
    std::string chunk(text, len);
    std::vector<ds4cpp::ParserEvent> events;
    s->parser.Feed(chunk, events);
    if (events.empty()) { s->tokens++; return; }

    backend::Reply reply;
    auto *delta = reply.add_chat_deltas();
    bool any_field = false;
    for (const auto &e : events) {
        switch (e.type) {
        case ds4cpp::ParserEvent::CONTENT:
            delta->set_content(delta->content() + e.text);
            any_field = true;
            break;
        case ds4cpp::ParserEvent::REASONING:
            delta->set_reasoning_content(delta->reasoning_content() + e.text);
            any_field = true;
            break;
        case ds4cpp::ParserEvent::TOOL_START: {
            if ((int)s->tool_started.size() <= e.index)
                s->tool_started.resize(e.index + 1, false);
            s->tool_started[e.index] = true;
            auto *tc = delta->add_tool_calls();
            tc->set_index(e.index);
            tc->set_id(e.tool_id);
            tc->set_name(e.tool_name);
            any_field = true;
            break;
        }
        case ds4cpp::ParserEvent::TOOL_ARGS: {
            auto *tc = delta->add_tool_calls();
            tc->set_index(e.index);
            tc->set_arguments(e.text);
            any_field = true;
            break;
        }
        case ds4cpp::ParserEvent::TOOL_END:
            // No marker delta needed - the Go side closes the tool call on
            // the final aggregator pass.
            break;
        }
    }
    reply.set_message(chunk);
    reply.set_tokens(1);
    if (any_field) {
        if (!s->writer->Write(reply)) s->aborted = true;
    }
    s->tokens++;
}
static void stream_done(void *) {}

// Per-thread RNG seed for ds4_session_sample. Initialized lazily from
// system_clock; ds4 owns the random walk after that.
static uint64_t *get_rng() {
    static thread_local uint64_t seed = 0;
    if (seed == 0) {
        seed = static_cast<uint64_t>(
            std::chrono::system_clock::now().time_since_epoch().count());
        if (seed == 0) seed = 1;
    }
    return &seed;
}

struct SampleParams {
    float temperature;
    int top_k;
    float top_p;
    float min_p;
};

// Compute the effective sampling parameters for the next token, mirroring
// ds4_server.c:7102-7115:
//   - thinking mode enabled -> override (T=1, top_k=0, top_p=1, min_p=0)
//   - inside DSML structural position (tool-call markers) -> force T=0
//   - otherwise -> the request's user-supplied sampling settings
// The parser argument carries state from tokens emitted so far; its
// IsInDsmlStructural() predicts the next token's classification.
static SampleParams compute_sample_params(const backend::PredictOptions *request,
                                          const ds4cpp::DsmlParser &parser,
                                          bool think_enabled);

static ds4_think_mode parse_think_mode(const backend::PredictOptions *request) {
    // Per the vllm backend convention, "enable_thinking" gates thinking on/off,
    // and "reasoning_effort" picks the strength when on.
    const auto &md = request->metadata();
    auto et = md.find("enable_thinking");
    bool enabled = true;  // default ON per ds4-server
    if (et != md.end()) enabled = (et->second == "true" || et->second == "1");
    if (!enabled) return DS4_THINK_NONE;
    auto re = md.find("reasoning_effort");
    if (re != md.end() && (re->second == "max" || re->second == "xhigh"))
        return DS4_THINK_MAX;
    return DS4_THINK_HIGH;
}
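
parse_think_mode() reduces to a small decision table over two metadata keys. The sketch below restates it over a plain `std::map` in place of the protobuf metadata map, with `Think` as a hypothetical stand-in for `ds4_think_mode`, so each row can be checked on its own.

```cpp
#include <map>
#include <string>

// Stand-in for ds4_think_mode (the real enum is in ds4.h).
enum class Think { None, High, Max };

// Same decision table as parse_think_mode().
static Think think_mode(const std::map<std::string, std::string> &md) {
    auto et = md.find("enable_thinking");
    bool enabled = true;  // thinking defaults to ON
    if (et != md.end()) enabled = (et->second == "true" || et->second == "1");
    if (!enabled) return Think::None;
    auto re = md.find("reasoning_effort");
    if (re != md.end() && (re->second == "max" || re->second == "xhigh"))
        return Think::Max;
    return Think::High;  // any other effort value, or none at all
}
```

One consequence worth knowing: once `enable_thinking` is present, only the exact strings `"true"` and `"1"` keep thinking on. A value such as `"yes"` or `"True"` turns it off. Separately, any `reasoning_effort` other than `max` or `xhigh` (for example `low`) still maps to the HIGH mode.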

static SampleParams compute_sample_params(const backend::PredictOptions *request,
                                          const ds4cpp::DsmlParser &parser,
                                          bool think_enabled) {
    SampleParams p = {
        request->temperature(),
        request->topk(),
        request->topp(),
        request->minp(),
    };
    if (think_enabled) {
        // Match ds4-server: thinking mode wants creativity in the reasoning
        // pass and the trailing content, so the entire generation overrides
        // sampling unless DSML structural bytes take over below.
        p.temperature = 1.0f;
        p.top_k = 0;
        p.top_p = 1.0f;
        p.min_p = 0.0f;
    }
    if (parser.IsInDsmlStructural()) {
        // Tool-call structural bytes (tags, markers, headers) must parse
        // cleanly. Force greedy regardless of user/thinking settings.
        p.temperature = 0.0f;
    }
    return p;
}
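
The precedence in compute_sample_params() runs in three layers: user settings, then the thinking override, then the DSML greedy override. The generation loops treat `temperature <= 0` as greedy argmax. The sketch below replaces the protobuf request and the parser with plain values. `Sampling`, `effective_sampling`, and `is_greedy` are hypothetical names.

```cpp
// Stand-in for SampleParams / the request's sampling fields.
struct Sampling {
    float temperature;
    int top_k;
    float top_p;
    float min_p;
};

// Later rules win: user settings < thinking override < DSML structural greedy.
static Sampling effective_sampling(Sampling user, bool think_enabled, bool in_dsml_structural) {
    Sampling p = user;
    if (think_enabled) p = {1.0f, 0, 1.0f, 0.0f};
    if (in_dsml_structural) p.temperature = 0.0f;  // only temperature is forced
    return p;
}

// The Predict loops pick argmax (and allow MTP) when temperature <= 0.
static bool is_greedy(const Sampling &p) { return p.temperature <= 0.0f; }
```

Inside a tool-call marker, the DSML rule forces only the temperature. The other thinking overrides stay set but have no effect, because greedy decoding goes through `ds4_session_argmax` rather than the sampler. These greedy steps are also the only ones where the MTP speculative path may run.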

// Build the text used as the disk KV cache key. It is a simplified rendering
// of the same inputs build_prompt() uses (role-tagged message contents, or the
// raw prompt), not the exact chat-template text the model sees; the cache is
// keyed on these bytes, not on tokens.
static std::string render_prompt_text(const backend::PredictOptions *request) {
    // Two-mode: either the raw prompt or the chat-template path. We mirror
    // build_prompt's branching but accumulate text (not tokens) so we can
    // SHA1 it for the cache key. ds4_session caches a tokens-indexed
    // checkpoint, but the disk format keys on bytes per ds4-server's design.
    if (!request->usetokenizertemplate() || request->messages_size() == 0) {
        return request->prompt();
    }
    std::string out;
    const std::string sys_role = "system";
    for (const auto &m : request->messages()) {
        if (m.role() == sys_role) { out += "[sys] " + m.content() + "\n"; break; }
    }
    for (const auto &m : request->messages()) {
        if (m.role() == sys_role) continue;
        out += "[" + m.role() + "] " + m.content() + "\n";
    }
    return out;
}
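
A worked example makes the cache-key layout concrete. The sketch below copies the chat-template branch of render_prompt_text() over a minimal hypothetical `Msg` struct in place of the protobuf message type. The first system message is hoisted to the top, and later system messages are skipped entirely.

```cpp
#include <string>
#include <vector>

struct Msg {
    std::string role, content;
};

// Chat-template branch of render_prompt_text().
static std::string render_cache_key_text(const std::vector<Msg> &messages) {
    std::string out;
    for (const auto &m : messages) {  // first system message only
        if (m.role == "system") { out += "[sys] " + m.content + "\n"; break; }
    }
    for (const auto &m : messages) {  // every non-system turn, in order
        if (m.role == "system") continue;
        out += "[" + m.role + "] " + m.content + "\n";
    }
    return out;
}
```

For the turns user "hi", system "be brief", assistant "hello", the key text is `"[sys] be brief\n[user] hi\n[assistant] hello\n"`. The key covers only roles and contents. Tool definitions, assistant `tool_calls`, and the thinking mode that build_prompt() also feeds the model are not part of it.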

ds4cpp::KvCache g_kv_cache;

// Try to recover prefill state for `rendered`. Returns the matched prefix length.
static size_t maybe_load_cache(const std::string &rendered) {
    if (!g_kv_cache.enabled() || !g_session) return 0;
    return g_kv_cache.LoadLongestPrefix(g_session, rendered, g_ctx_size);
}

static void maybe_save_cache(const std::string &rendered) {
    if (g_kv_cache.enabled() && g_session) {
        g_kv_cache.Save(g_session, rendered, g_ctx_size);
    }
}

static void build_prompt(ds4_engine *engine, const backend::PredictOptions *request,
                         ds4_tokens *out) {
    if (!request->usetokenizertemplate() || request->messages_size() == 0) {
        ds4_tokenize_text(engine, request->prompt().c_str(), out);
        return;
    }
    // Chat-template path: render via ds4's helpers.
    ds4_chat_begin(engine, out);

    ds4_think_mode think = parse_think_mode(request);

    // ds4_encode_chat_prompt is convenient when there is exactly one
    // system+user pair, but for arbitrary turn lists we use the granular
    // append helpers. Pull the first system message (if any), then append
    // every other message in order.
    const std::string sys_role = "system";
    std::string system_text;
    for (const auto &m : request->messages()) {
        if (m.role() == sys_role) { system_text = m.content(); break; }
    }
    // Inject the tools manifest into the system prompt when tools are present.
    // ds4 was trained to emit DSML tool calls ONLY when this preamble is in
    // the system message - without it, the model has no idea tools exist and
    // the e2e tool-call test will fail. The renderer lives in dsml_renderer
    // and is a verbatim port of ds4_server.c's append_tools_prompt_text.
    std::string tools_manifest;
    if (!request->tools().empty()) {
        tools_manifest = ds4cpp::RenderToolsManifest(request->tools());
    }
    if (!system_text.empty() || !tools_manifest.empty()) {
        std::string combined = system_text;
        if (!tools_manifest.empty()) {
            if (!combined.empty()) combined += "\n\n";
            combined += tools_manifest;
        }
        ds4_chat_append_message(engine, out, "system", combined.c_str());
    }
    for (const auto &m : request->messages()) {
        if (m.role() == sys_role) continue;
        if (m.role() == "assistant" && !m.tool_calls().empty()) {
            std::string combined = m.content();
            combined += ds4cpp::RenderAssistantToolCalls(m.tool_calls());
            ds4_chat_append_message(engine, out, "assistant", combined.c_str());
        } else if (m.role() == "tool") {
            std::string body = ds4cpp::RenderToolResult(m.tool_call_id(), m.content());
            ds4_chat_append_message(engine, out, "user", body.c_str());
        } else {
            ds4_chat_append_message(engine, out, m.role().c_str(), m.content().c_str());
        }
    }
    ds4_chat_append_assistant_prefix(engine, out, think);
}
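
The turn order build_prompt() produces can be listed as (role, text) pairs. The first system message is merged with the tools manifest using a blank line, assistant turns with tool calls get the rendered calls appended, and `tool` results are re-sent as `user` turns. The sketch below records that plan without calling ds4. `ChatMsg`, `plan_turns`, and both `render_*_placeholder` functions are hypothetical; the real DSML renderings from dsml_renderer differ, and the real RenderToolResult also takes the tool_call_id. The final assistant prefix, which carries the thinking mode, is not modeled.

```cpp
#include <string>
#include <utility>
#include <vector>

struct ChatMsg {
    std::string role, content;
    bool has_tool_calls = false;
};

// Placeholders for ds4cpp::RenderAssistantToolCalls / RenderToolResult.
static std::string render_calls_placeholder() { return "<calls>"; }
static std::string render_result_placeholder(const std::string &c) {
    return "<result>" + c + "</result>";
}

// (role, text) pairs in the order build_prompt() appends them.
static std::vector<std::pair<std::string, std::string>>
plan_turns(const std::vector<ChatMsg> &msgs, const std::string &tools_manifest) {
    std::vector<std::pair<std::string, std::string>> turns;
    std::string system_text;
    for (const auto &m : msgs)
        if (m.role == "system") { system_text = m.content; break; }
    if (!system_text.empty() || !tools_manifest.empty()) {
        std::string combined = system_text;
        if (!tools_manifest.empty()) {
            if (!combined.empty()) combined += "\n\n";
            combined += tools_manifest;
        }
        turns.emplace_back("system", combined);
    }
    for (const auto &m : msgs) {
        if (m.role == "system") continue;
        if (m.role == "assistant" && m.has_tool_calls)
            turns.emplace_back("assistant", m.content + render_calls_placeholder());
        else if (m.role == "tool")
            turns.emplace_back("user", render_result_placeholder(m.content));
        else
            turns.emplace_back(m.role, m.content);
    }
    return turns;
}
```

A request that has tools but no system message still gets a system turn. In that case the turn contains only the manifest, matching the comment above that ds4 emits DSML tool calls only when this preamble is present.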

class DS4Backend final : public backend::Backend::Service {
public:
    GStatus Health(ServerContext *, const backend::HealthMessage *,
                   backend::Reply *reply) override {
        reply->set_message(std::string("OK"));
        return GStatus::OK;
    }

    GStatus Free(ServerContext *, const backend::HealthMessage *,
                 backend::Result *result) override {
        std::lock_guard<std::mutex> lock(g_engine_mu);
        if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
        if (g_engine) { ds4_engine_close(g_engine); g_engine = nullptr; }
        result->set_success(true);
        return GStatus::OK;
    }

    GStatus LoadModel(ServerContext *, const backend::ModelOptions *request,
                      backend::Result *result) override {
        std::lock_guard<std::mutex> lock(g_engine_mu);

        // Reset distributed state so a model swap (a second LoadModel without
        // ds4_role) doesn't inherit a stale coordinator configuration.
        g_distributed = false;
        g_route_timeout_sec = 60;

        if (g_engine) {
            if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
            ds4_engine_close(g_engine);
            g_engine = nullptr;
        }

        std::string model_path = request->modelfile();
        if (model_path.empty()) model_path = request->model();
        if (model_path.empty()) {
            result->set_success(false);
            result->set_message("ds4: ModelOptions.Model or .ModelFile must be set");
            return GStatus::OK;
        }

        std::string mtp_path;
        int mtp_draft = 0;
        float mtp_margin = 3.0f;
        std::string ds4_role, ds4_layers, ds4_listen;
        for (const auto &opt : request->options()) {
            auto [k, v] = split_option(opt);
            if (k == "mtp_path") mtp_path = v;
            else if (k == "mtp_draft" || k == "mtp_margin") {
                try {
                    if (k == "mtp_draft") mtp_draft = std::stoi(v);
                    else mtp_margin = std::stof(v);
                } catch (const std::exception &) {
                    // std::stoi/std::stof throw on malformed input; report a
                    // load error instead of letting the exception escape the
                    // RPC handler.
                    result->set_success(false);
                    result->set_message("ds4: " + k + " must be numeric");
                    return GStatus::OK;
                }
            }
            else if (k == "kv_cache_dir") g_kv_cache_dir = v;
            else if (k == "ds4_role") ds4_role = v;
            else if (k == "ds4_layers") ds4_layers = v;
            else if (k == "ds4_listen") ds4_listen = v;
            else if (k == "ds4_route_timeout") {
                if (!parse_positive_int(v, &g_route_timeout_sec)) {
                    result->set_success(false);
                    result->set_message("ds4: ds4_route_timeout must be a positive integer");
                    return GStatus::OK;
                }
            }
        }

        g_kv_cache.SetDir(g_kv_cache_dir);

        ds4_engine_options opt = {};
        opt.model_path = model_path.c_str();
        opt.mtp_path = mtp_path.empty() ? nullptr : mtp_path.c_str();
        opt.n_threads = request->threads() > 0 ? request->threads() : 0;
        opt.mtp_draft_tokens = mtp_draft;
        opt.mtp_margin = mtp_margin;
        opt.directional_steering_file = nullptr;
        opt.warm_weights = false;
        opt.quality = false;

#if defined(DS4_NO_GPU)
        opt.backend = DS4_BACKEND_CPU;
#elif defined(__APPLE__)
        opt.backend = DS4_BACKEND_METAL;
#else
        opt.backend = DS4_BACKEND_CUDA;
#endif

        // Coordinator wiring. 'ds4_role:coordinator' enables layer-split
        // distributed inference: this process listens on ds4_listen and owns
        // the ds4_layers slice; workers dial in (see `local-ai worker
        // ds4-distributed`). Absent ds4_role => unchanged single-node path.
        // Must be static: opt.distributed.listen_host is a const char* the
        // engine retains past this call, so it cannot point at a local that
        // goes out of scope (a future "simplify to local" refactor would
        // reintroduce a dangling pointer).
        static std::string s_listen_host;
        if (ds4_role == "coordinator") {
            if (ds4_layers.empty() || ds4_listen.empty()) {
                result->set_success(false);
                result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
                return GStatus::OK;
            }
            // host:port for IPv4/hostname; IPv6 literals are unsupported (the
            // first colon would split inside the address).
            auto host_port = split_option(ds4_listen);  // "host:port" -> {host, port}
            if (host_port.second.empty()) {
                result->set_success(false);
                result->set_message("ds4: ds4_listen must be host:port");
                return GStatus::OK;
            }
            int listen_port = 0;
            if (!parse_positive_int(host_port.second, &listen_port) || listen_port > 65535) {
                // TCP ports are 1-65535; larger values would be silently
                // truncated when converted to a 16-bit port.
                result->set_success(false);
                result->set_message("ds4: ds4_listen port must be an integer in 1-65535");
                return GStatus::OK;
            }
            ds4_distributed_layers layers = {};
            if (!parse_layers_spec(ds4_layers, &layers)) {
                result->set_success(false);
                result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
                return GStatus::OK;
            }
            s_listen_host = host_port.first;
            opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
            opt.distributed.layers = layers;
            opt.distributed.listen_host = s_listen_host.c_str();
            opt.distributed.listen_port = listen_port;
            g_distributed = true;
        }

        int rc = ds4_engine_open(&g_engine, &opt);
        if (rc != 0 || !g_engine) {
            result->set_success(false);
            result->set_message("ds4_engine_open failed (rc=" + std::to_string(rc) + ")");
            return GStatus::OK;
        }

        g_ctx_size = request->contextsize() > 0 ? request->contextsize() : 32768;
        rc = ds4_session_create(&g_session, g_engine, g_ctx_size);
        if (rc != 0 || !g_session) {
            ds4_engine_close(g_engine);
            g_engine = nullptr;
            result->set_success(false);
            result->set_message("ds4_session_create failed (rc=" + std::to_string(rc) + ")");
            return GStatus::OK;
        }

        result->set_success(true);
        result->set_message("loaded " + model_path);
        return GStatus::OK;
    }

    GStatus TokenizeString(ServerContext *, const backend::PredictOptions *request,
                           backend::TokenizationResponse *response) override {
        std::lock_guard<std::mutex> lock(g_engine_mu);
        if (!g_engine) return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
        ds4_tokens out = {};
        ds4_tokenize_text(g_engine, request->prompt().c_str(), &out);
        for (int i = 0; i < out.len; ++i) response->add_tokens(out.v[i]);
        response->set_length(out.len);
        ds4_tokens_free(&out);
        return GStatus::OK;
    }

    GStatus Predict(ServerContext *, const backend::PredictOptions *request,
                    backend::Reply *reply) override {
        std::unique_lock<std::mutex> lock(g_engine_mu);
        if (!g_engine || !g_session) {
            return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
        }
        if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
            return GStatus(StatusCode::UNAVAILABLE, route_err);
        }
        ds4_tokens prompt = {};
        build_prompt(g_engine, request, &prompt);
        int n_predict = request->tokens() > 0 ? request->tokens() : 256;

        CollectCtx collect = {g_engine, "", {}, reply, 0, {}, "", ""};
        std::string cache_key = render_prompt_text(request);
        size_t cache_hit = maybe_load_cache(cache_key);
        (void)cache_hit;  // future: skip prompt prefix if hit covers full prompt

        // Manual generation loop on g_session. When MTP speculative weights
        // were loaded (LoadModel option 'mtp_path:'), we use the
        // ds4_session_eval_speculative_argmax path which may accept N>1
        // tokens per outer iteration. Otherwise per-token argmax + eval.
        // Either way g_session advances so the disk KV cache picks up a
        // real checkpoint after the call (see maybe_save_cache below).
        char err[256] = {0};
        int rc = ds4_session_sync(g_session, &prompt, err, sizeof(err));
        int prompt_len = prompt.len;
        ds4_tokens_free(&prompt);
        if (rc == 0) {
            const int eos = ds4_token_eos(g_engine);
            const int draft_max = ds4_engine_mtp_draft_tokens(g_engine);
            const bool think_enabled = ds4_think_mode_enabled(parse_think_mode(request));
            int produced = 0;
            while (produced < n_predict) {
                SampleParams sp = compute_sample_params(request, collect.parser, think_enabled);
                int first;
                if (sp.temperature <= 0.0f) {
                    first = ds4_session_argmax(g_session);
                } else {
                    first = ds4_session_sample(g_session,
                                               sp.temperature, sp.top_k,
                                               sp.top_p, sp.min_p, get_rng());
                }
                if (first == eos) break;
                // MTP only when sampling is greedy (ds4-server gate).
                if (draft_max > 0 && sp.temperature <= 0.0f) {
                    constexpr int kAcceptedMax = 8;
                    int accepted[kAcceptedMax];
                    int cap = std::min(kAcceptedMax, draft_max + 1);
                    int n = ds4_session_eval_speculative_argmax(
                        g_session, first, draft_max, eos,
                        accepted, cap, err, sizeof(err));
                    if (n < 0) { rc = -1; break; }
                    bool stop = false;
                    for (int j = 0; j < n; ++j) {
                        if (accepted[j] == eos) { stop = true; break; }
                        collect_emit(&collect, accepted[j]);
                        if (++produced >= n_predict) { stop = true; break; }
                    }
                    if (stop) break;
                } else {
                    collect_emit(&collect, first);
                    if (++produced >= n_predict) break;
                    rc = ds4_session_eval(g_session, first, err, sizeof(err));
                    if (rc != 0) break;
                }
            }
            collect_done(&collect);
        }
        maybe_save_cache(cache_key);

        // Flush any buffered parser state.
        std::vector<ds4cpp::ParserEvent> events;
        collect.parser.Flush(events);
        apply_events(&collect, events);

        if (rc != 0) {
            return GStatus(StatusCode::INTERNAL,
                           std::string("ds4 generation failed: ") + err);
        }

        // Emit one ChatDelta with content/reasoning/tool_calls.
        auto *delta = reply->add_chat_deltas();
        delta->set_content(collect.content_buf);
        delta->set_reasoning_content(collect.reasoning_buf);
        for (size_t i = 0; i < collect.pending.size(); ++i) {
            auto *tc = delta->add_tool_calls();
            tc->set_index(static_cast<int32_t>(i));
            tc->set_id(collect.pending[i].id);
            tc->set_name(collect.pending[i].name);
            tc->set_arguments(collect.pending[i].args);
        }

        reply->set_message(collect.raw_buf);
        reply->set_tokens(collect.tokens);
        reply->set_prompt_tokens(prompt_len);
        return GStatus::OK;
    }
|
|
|
|
    GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
                          ServerWriter<backend::Reply> *writer) override {
        std::unique_lock<std::mutex> lock(g_engine_mu);
        if (!g_engine || !g_session) {
            return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
        }
        if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
            return GStatus(StatusCode::UNAVAILABLE, route_err);
        }
        ds4_tokens prompt = {};
        build_prompt(g_engine, request, &prompt);
        int n_predict = request->tokens() > 0 ? request->tokens() : 256;

        StreamCtx s = {g_engine, writer, {}, 0, false, {}};
        std::string cache_key = render_prompt_text(request);
        size_t cache_hit = maybe_load_cache(cache_key);
        (void)cache_hit;

        // Manual loop on g_session - see Predict() above for the rationale.
        // The MTP speculative path is used when ds4_engine_mtp_draft_tokens() > 0
        // and sampling is greedy (temperature <= 0).
        char err[256] = {0};
        int rc = ds4_session_sync(g_session, &prompt, err, sizeof(err));
        ds4_tokens_free(&prompt);
        if (rc == 0) {
            const int eos = ds4_token_eos(g_engine);
            const int draft_max = ds4_engine_mtp_draft_tokens(g_engine);
            const bool think_enabled = ds4_think_mode_enabled(parse_think_mode(request));
            int produced = 0;
            while (produced < n_predict && !s.aborted) {
                SampleParams sp = compute_sample_params(request, s.parser, think_enabled);
                int first;
                if (sp.temperature <= 0.0f) {
                    first = ds4_session_argmax(g_session);
                } else {
                    first = ds4_session_sample(g_session,
                                               sp.temperature, sp.top_k,
                                               sp.top_p, sp.min_p, get_rng());
                }
                if (first == eos) break;
                if (draft_max > 0 && sp.temperature <= 0.0f) {
                    constexpr int kAcceptedMax = 8;
                    int accepted[kAcceptedMax];
                    int cap = std::min(kAcceptedMax, draft_max + 1);
                    int n = ds4_session_eval_speculative_argmax(
                        g_session, first, draft_max, eos,
                        accepted, cap, err, sizeof(err));
                    if (n < 0) { rc = -1; break; }
                    bool stop = false;
                    for (int j = 0; j < n; ++j) {
                        if (accepted[j] == eos) { stop = true; break; }
                        stream_emit(&s, accepted[j]);
                        if (s.aborted) { stop = true; break; }
                        if (++produced >= n_predict) { stop = true; break; }
                    }
                    if (stop) break;
                } else {
                    stream_emit(&s, first);
                    if (s.aborted || ++produced >= n_predict) break;
                    rc = ds4_session_eval(g_session, first, err, sizeof(err));
                    if (rc != 0) break;
                }
            }
            stream_done(&s);
        }
        maybe_save_cache(cache_key);

        // Flush any buffered parser state and send the remaining
        // content/reasoning as one final delta.
        std::vector<ds4cpp::ParserEvent> events;
        s.parser.Flush(events);
        if (!events.empty() && !s.aborted) {
            backend::Reply reply;
            auto *delta = reply.add_chat_deltas();
            for (const auto &e : events) {
                if (e.type == ds4cpp::ParserEvent::CONTENT) {
                    delta->set_content(delta->content() + e.text);
                } else if (e.type == ds4cpp::ParserEvent::REASONING) {
                    delta->set_reasoning_content(delta->reasoning_content() + e.text);
                }
            }
            s.writer->Write(reply);
        }

        if (rc != 0 && !s.aborted) {
            return GStatus(StatusCode::INTERNAL,
                           std::string("ds4 generation failed: ") + err);
        }
        return GStatus::OK;
    }

    GStatus Status(ServerContext *, const backend::HealthMessage *,
                   backend::StatusResponse *response) override {
        std::lock_guard<std::mutex> lock(g_engine_mu);
        response->set_state(g_engine ? backend::StatusResponse::READY
                                     : backend::StatusResponse::UNINITIALIZED);
        return GStatus::OK;
    }
};

void RunServer(const std::string &addr) {
    DS4Backend service;
    grpc::EnableDefaultHealthCheckService(true);
    grpc::reflection::InitProtoReflectionServerBuilderPlugin();

    ServerBuilder builder;
    builder.AddListeningPort(addr, grpc::InsecureServerCredentials());
    builder.RegisterService(&service);
    builder.SetMaxReceiveMessageSize(64 * 1024 * 1024);
    builder.SetMaxSendMessageSize(64 * 1024 * 1024);

    std::unique_ptr<Server> server(builder.BuildAndStart());
    if (!server) {
        std::cerr << "ds4 grpc-server: failed to bind " << addr << "\n";
        std::exit(1);
    }
    g_server = server.get();
    std::cerr << "ds4 grpc-server listening on " << addr << "\n";
    server->Wait();
}

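void signal_handler(int) {
    // Stop the server: in-flight RPCs get up to 3 s to complete before
    // they are cancelled, after which server->Wait() in RunServer returns.
    if (auto *srv = g_server.load()) {
        srv->Shutdown(std::chrono::system_clock::now() +
                      std::chrono::seconds(3));
    }
}

} // namespace
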
int main(int argc, char *argv[]) {
    std::string addr = "127.0.0.1:50051";
    for (int i = 1; i < argc; ++i) {
        std::string a = argv[i];
        const std::string addr_flag = "--addr=";
        if (a.rfind(addr_flag, 0) == 0) addr = a.substr(addr_flag.size());
        else if (a == "--addr" && i + 1 < argc) addr = argv[++i];
        else if (a == "--help" || a == "-h") {
            std::cout << "Usage: grpc-server --addr=HOST:PORT\n";
            return 0;
        }
    }
    std::signal(SIGINT, signal_handler);
    std::signal(SIGTERM, signal_handler);
    RunServer(addr);
    return 0;
}