mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-18 21:58:58 -04:00
feat(ds4): wire SSD streaming + quality engine options, add 128GB DeepSeek gallery models The ds4 backend zero-initialized ds4_engine_options and exposed none of the engine's tunable knobs, so SSD streaming (run a model larger than RAM by streaming routed MoE experts from the GGUF on SSD) and the quality/perf knobs were unreachable from LocalAI model YAMLs. Map ModelOptions.Options onto ds4_engine_options through a declarative table (kEngineOptSpecs + apply_engine_option) instead of per-field branches: the struct is fixed C with no reflection, so the field set is enumerated once and a future knob is a one-line table row. Two fields use ds4's own typed parsers (GiB budgets, cache-experts count-or-NGB). Bare flags (e.g. "ssd_streaming") mean true; path-type options (mtp_path, expert_profile_path, directional_steering_file) resolve relative to the model directory so a gallery entry can reference a companion file by bare filename. mtp_draft/mtp_margin are now validated rather than parsed with throwing std::stoi/std::stof. Add gallery entries for the 128 GB class: - deepseek-v4-flash-q2-q4 (~91 GB, mixed q2/q4, fits RAM, higher quality) - deepseek-v4-flash-q4-ssd (~153 GB full 4-bit, runs on 128 GB via SSD streaming) - deepseek-v4-flash-q2-mtp (~81 GB + MTP speculative draft weights) - deepseek-v4-pro-q2-ssd (~433 GB Pro, experimental SSD streaming) SSD streaming is Metal (Darwin) only; the options are inert on CUDA/CPU. Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
979 lines
41 KiB
C++
979 lines
41 KiB
C++
// ds4 LocalAI gRPC backend.
|
|
//
|
|
// Wraps antirez/ds4's `ds4_engine_*` / `ds4_session_*` public API
|
|
// (see ds4/ds4.h) over LocalAI's backend.proto. Tool calls, thinking
|
|
// mode, and disk KV cache are wired in follow-up commits; this commit
|
|
// is just the bind/listen/Health/Free skeleton.
|
|
|
|
#include "backend.pb.h"
|
|
#include "backend.grpc.pb.h"
|
|
|
|
#include "dsml_parser.h" // populated in Task 12
|
|
#include "dsml_renderer.h" // populated in Task 16
|
|
#include "kv_cache.h" // populated in Task 17
|
|
|
|
extern "C" {
|
|
#include "ds4.h"
|
|
}
|
|
|
|
#include <grpcpp/grpcpp.h>
|
|
#include <grpcpp/server.h>
|
|
#include <grpcpp/server_builder.h>
|
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
|
|
|
#include <atomic>
|
|
#include <chrono>
|
|
#include <climits>
|
|
#include <csignal>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
#include <ctime>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <mutex>
|
|
#include <string>
|
|
#include <thread>
|
|
#include <vector>
|
|
|
|
using grpc::Server;
|
|
using grpc::ServerBuilder;
|
|
using grpc::ServerContext;
|
|
using grpc::ServerWriter;
|
|
// NOTE: do NOT alias `grpc::Status` as `Status` - the Status RPC method below
|
|
// would shadow the type, breaking the other RPC method declarations that use
|
|
// it as a return type. Use GStatus instead.
|
|
using GStatus = ::grpc::Status;
|
|
using grpc::StatusCode;
|
|
|
|
namespace {
|
|
|
|
// Global state - ds4 is single-engine-per-process by design.
|
|
std::mutex g_engine_mu;
|
|
ds4_engine *g_engine = nullptr;
|
|
ds4_session *g_session = nullptr;
|
|
int g_ctx_size = 32768;
|
|
std::string g_kv_cache_dir; // empty disables disk cache
|
|
|
|
// Distributed coordinator state. g_distributed is set true when LoadModel is
|
|
// given 'ds4_role:coordinator'; generation then waits for the worker route to
|
|
// form before running. Single-node behavior is unchanged when unset.
|
|
bool g_distributed = false;
|
|
int g_route_timeout_sec = 60;
|
|
|
|
std::atomic<Server *> g_server{nullptr};
|
|
|
|
// Split a "key:value" option string at its FIRST colon. An option without a
// colon (a bare flag) is returned as {option, ""}; any further colons stay in
// the value part untouched.
static std::pair<std::string, std::string> split_option(const std::string &opt) {
    const std::string::size_type sep = opt.find(':');
    if (sep != std::string::npos) {
        return std::make_pair(opt.substr(0, sep), opt.substr(sep + 1));
    }
    return std::make_pair(opt, std::string());
}
|
|
|
|
// Parse a positive base-10 integer into *out.
//
// Returns false - without throwing and without touching *out - when the input
// is empty, is not entirely a number (trailing garbage, embedded NUL), is zero
// or negative, or does not fit in an int.
//
// The value is parsed as long long so the "> INT_MAX" test rejects overflow on
// every platform; parsing into a long only works where long is wider than int.
static bool parse_positive_int(const std::string &s, int *out) {
    if (s.empty()) return false;
    char *end = nullptr;
    const long long v = std::strtoll(s.c_str(), &end, 10);
    // The parse must consume the whole string, not stop at the first bad byte.
    if (end != s.c_str() + s.size()) return false;
    if (v <= 0 || v > static_cast<long long>(INT_MAX)) return false;
    *out = static_cast<int>(v);
    return true;
}
|
|
|
|
// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
|
|
// distributed layer fields. Returns false on malformed input.
|
|
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
|
|
auto colon = spec.find(':');
|
|
if (colon == std::string::npos) return false;
|
|
std::string lhs = spec.substr(0, colon);
|
|
std::string rhs = spec.substr(colon + 1);
|
|
if (lhs.empty() || rhs.empty()) return false;
|
|
char *end = nullptr;
|
|
long start = std::strtol(lhs.c_str(), &end, 10);
|
|
if (!end || *end != '\0' || start < 0) return false;
|
|
out->start = static_cast<uint32_t>(start);
|
|
out->has_output = false;
|
|
if (rhs == "output") {
|
|
out->has_output = true;
|
|
out->end = out->start; // engine treats has_output as "through final layer"
|
|
} else {
|
|
long e = std::strtol(rhs.c_str(), &end, 10);
|
|
if (!end || *end != '\0' || e < start) return false;
|
|
out->end = static_cast<uint32_t>(e);
|
|
}
|
|
out->set = true;
|
|
return true;
|
|
}
|
|
|
|
// Parse a boolean LoadModel option into *out; returns false (leaving *out
// alone) for an unrecognized spelling. The empty string counts as true so a
// bare flag-style option with no colon - options: ["ssd_streaming"] in a model
// YAML - switches the feature on.
static bool parse_bool_option(const std::string &s, bool *out) {
    static const char *const kTruthy[] = {"", "true", "1", "yes", "on"};
    static const char *const kFalsy[] = {"false", "0", "no", "off"};
    for (const char *word : kTruthy) {
        if (s == word) {
            *out = true;
            return true;
        }
    }
    for (const char *word : kFalsy) {
        if (s == word) {
            *out = false;
            return true;
        }
    }
    return false;
}
|
|
|
|
// Value kinds for the table-driven mapping from LoadModel option keys to
// ds4_engine_options fields. ds4_engine_options is a fixed C struct with no
// reflection, so its tunable fields are enumerated once in a table; a future
// engine knob becomes a one-line table entry rather than a new branch in
// LoadModel. Gib and CacheExperts exist because those two fields go through
// ds4's own typed parsers and cannot be handled as a plain passthrough.
enum class DsOptType {
    Bool,         // true/false switch (bare flag means true)
    Int,          // signed integer
    Uint,         // non-negative 32-bit integer
    Float,        // floating-point number
    Str,          // string, optionally a path
    Gib,          // size parsed by ds4_parse_gib_arg
    CacheExperts, // parsed by ds4_parse_streaming_cache_experts_arg
};
|
|
|
|
struct DsOptSpec {
|
|
const char *key;
|
|
DsOptType type;
|
|
size_t off; // byte offset into ds4_engine_options
|
|
size_t off2; // second offset (CacheExperts writes experts + bytes)
|
|
bool is_path; // Str values: resolve a relative value against the model dir
|
|
};
|
|
|
|
// Option key -> ds4_engine_options field table consulted by
// apply_engine_option. Columns: key, value type, field offset, second field
// offset (CacheExperts only), is_path.
static const DsOptSpec kEngineOptSpecs[] = {
    // Speculative decoding (MTP).
    {"mtp_path", DsOptType::Str, offsetof(ds4_engine_options, mtp_path), 0, true},
    {"mtp_draft", DsOptType::Int, offsetof(ds4_engine_options, mtp_draft_tokens), 0, false},
    {"mtp_margin", DsOptType::Float, offsetof(ds4_engine_options, mtp_margin), 0, false},

    // General engine knobs.
    {"prefill_chunk", DsOptType::Uint, offsetof(ds4_engine_options, prefill_chunk), 0, false},
    {"power_percent", DsOptType::Int, offsetof(ds4_engine_options, power_percent), 0, false},
    {"warm_weights", DsOptType::Bool, offsetof(ds4_engine_options, warm_weights), 0, false},
    {"quality", DsOptType::Bool, offsetof(ds4_engine_options, quality), 0, false},

    // SSD streaming.
    {"ssd_streaming", DsOptType::Bool, offsetof(ds4_engine_options, ssd_streaming), 0, false},
    {"ssd_streaming_cold", DsOptType::Bool, offsetof(ds4_engine_options, ssd_streaming_cold), 0,
     false},
    {"ssd_streaming_preload_experts", DsOptType::Uint,
     offsetof(ds4_engine_options, ssd_streaming_preload_experts), 0, false},
    {"ssd_streaming_cache_experts", DsOptType::CacheExperts,
     offsetof(ds4_engine_options, ssd_streaming_cache_experts),
     offsetof(ds4_engine_options, ssd_streaming_cache_bytes), false},
    {"simulate_used_memory", DsOptType::Gib,
     offsetof(ds4_engine_options, simulate_used_memory_bytes), 0, false},
    {"expert_profile_path", DsOptType::Str, offsetof(ds4_engine_options, expert_profile_path), 0,
     true},

    // Directional steering.
    {"directional_steering_file", DsOptType::Str,
     offsetof(ds4_engine_options, directional_steering_file), 0, true},
    {"directional_steering_attn", DsOptType::Float,
     offsetof(ds4_engine_options, directional_steering_attn), 0, false},
    {"directional_steering_ffn", DsOptType::Float,
     offsetof(ds4_engine_options, directional_steering_ffn), 0, false},
};
|
|
|
|
// Apply a single key:value LoadModel option to the engine options struct.
|
|
// Unknown keys are ignored (back-compat: callers pass mixed option sets).
|
|
// String values are copied into `storage`, whose elements the engine reads by
|
|
// pointer during ds4_engine_open; `storage` MUST have reserved capacity so
|
|
// push_back never reallocates and dangles an earlier c_str(). Returns false
|
|
// with `err` set when a recognized key has an invalid value.
|
|
static bool apply_engine_option(ds4_engine_options *opt, const std::string &key,
|
|
const std::string &val, const std::string &model_dir,
|
|
std::vector<std::string> &storage, std::string &err) {
|
|
const DsOptSpec *spec = nullptr;
|
|
for (const auto &s : kEngineOptSpecs) {
|
|
if (key == s.key) { spec = &s; break; }
|
|
}
|
|
if (!spec) return true; // unknown key: ignore
|
|
|
|
char *base = reinterpret_cast<char *>(opt);
|
|
switch (spec->type) {
|
|
case DsOptType::Bool: {
|
|
bool b = false;
|
|
if (!parse_bool_option(val, &b)) { err = key + " must be true/false"; return false; }
|
|
*reinterpret_cast<bool *>(base + spec->off) = b;
|
|
return true;
|
|
}
|
|
case DsOptType::Int: {
|
|
char *end = nullptr;
|
|
long v = std::strtol(val.c_str(), &end, 10);
|
|
if (val.empty() || !end || *end != '\0') { err = key + " must be an integer"; return false; }
|
|
*reinterpret_cast<int *>(base + spec->off) = static_cast<int>(v);
|
|
return true;
|
|
}
|
|
case DsOptType::Uint: {
|
|
char *end = nullptr;
|
|
long v = std::strtol(val.c_str(), &end, 10);
|
|
if (val.empty() || !end || *end != '\0' || v < 0 || v > static_cast<long>(UINT32_MAX)) {
|
|
err = key + " must be a non-negative integer"; return false;
|
|
}
|
|
*reinterpret_cast<uint32_t *>(base + spec->off) = static_cast<uint32_t>(v);
|
|
return true;
|
|
}
|
|
case DsOptType::Float: {
|
|
char *end = nullptr;
|
|
float f = std::strtof(val.c_str(), &end);
|
|
if (val.empty() || !end || *end != '\0') { err = key + " must be a number"; return false; }
|
|
*reinterpret_cast<float *>(base + spec->off) = f;
|
|
return true;
|
|
}
|
|
case DsOptType::Str: {
|
|
// Resolve a relative path option (e.g. mtp_path: a sibling GGUF the
|
|
// gallery downloaded next to the model) against the model directory, so
|
|
// YAMLs reference companion files by name. Absolute values pass through.
|
|
if (spec->is_path && !model_dir.empty() && !val.empty() && val.front() != '/') {
|
|
storage.push_back(model_dir + "/" + val);
|
|
} else {
|
|
storage.push_back(val);
|
|
}
|
|
*reinterpret_cast<const char **>(base + spec->off) = storage.back().c_str();
|
|
return true;
|
|
}
|
|
case DsOptType::Gib: {
|
|
uint64_t bytes = 0;
|
|
if (!ds4_parse_gib_arg(val.c_str(), &bytes)) {
|
|
err = key + " must be a GiB value, e.g. 64GB"; return false;
|
|
}
|
|
*reinterpret_cast<uint64_t *>(base + spec->off) = bytes;
|
|
return true;
|
|
}
|
|
case DsOptType::CacheExperts: {
|
|
uint32_t experts = 0;
|
|
uint64_t bytes = 0;
|
|
if (!ds4_parse_streaming_cache_experts_arg(val.c_str(), &experts, &bytes)) {
|
|
err = key + " must be a positive expert count or a <number>GB budget"; return false;
|
|
}
|
|
*reinterpret_cast<uint32_t *>(base + spec->off) = experts;
|
|
*reinterpret_cast<uint64_t *>(base + spec->off2) = bytes;
|
|
return true;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// When acting as a distributed coordinator, block until the worker route
|
|
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
|
// elapses. Returns an empty string on success, or an error message to return
|
|
// to the client. No-op when not distributed.
|
|
//
|
|
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
|
|
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
|
|
// connect; holding g_engine_mu the whole time would block the Status/Health
|
|
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
|
|
// a still-starting worker as hung.
|
|
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
|
|
if (!g_distributed) return "";
|
|
char err[256] = {0};
|
|
const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
|
|
for (int i = 0; i <= deadline_polls; ++i) {
|
|
int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
|
|
if (ready == 1) return "";
|
|
if (ready < 0) {
|
|
return std::string("ds4 distributed route error: ") +
|
|
(err[0] ? err : "unknown");
|
|
}
|
|
// Release the lock while sleeping so Status/Health and other RPCs can
|
|
// interleave during worker startup.
|
|
lock.unlock();
|
|
struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
|
|
nanosleep(&ts, nullptr);
|
|
lock.lock();
|
|
// A concurrent Free() may have torn down the engine while we slept.
|
|
if (!g_engine || !g_session) {
|
|
return "ds4: model unloaded while waiting for distributed route";
|
|
}
|
|
}
|
|
return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
|
|
}
|
|
|
|
// Append the text ds4_token_text reports for `token` to `out`. Nothing is
// appended when the engine returns no text or a zero length.
static void append_token_text(ds4_engine *engine, int token, std::string &out) {
    size_t n_bytes = 0;
    const char *piece = ds4_token_text(engine, token, &n_bytes);
    if (piece == nullptr || n_bytes == 0) return;
    out.append(piece, n_bytes);
}
|
|
|
|
struct CollectCtx {
|
|
ds4_engine *engine;
|
|
std::string raw_buf; // exact raw bytes for Reply.message
|
|
ds4cpp::DsmlParser parser;
|
|
backend::Reply *reply;
|
|
int tokens;
|
|
|
|
// Per-tool aggregation: accumulate ChatDelta tool_calls so we emit one
|
|
// delta with all calls, mirroring how vllm's non-streaming path returns.
|
|
struct Pending {
|
|
std::string id;
|
|
std::string name;
|
|
std::string args;
|
|
};
|
|
std::vector<Pending> pending;
|
|
|
|
std::string content_buf;
|
|
std::string reasoning_buf;
|
|
};
|
|
|
|
// Fold a batch of parser events into the collect context: content and
// reasoning text are appended to their buffers, tool-call events populate the
// `pending` slot selected by the event's index.
static void apply_events(CollectCtx *c, const std::vector<ds4cpp::ParserEvent> &events) {
    auto &calls = c->pending;
    for (const auto &ev : events) {
        switch (ev.type) {
        case ds4cpp::ParserEvent::CONTENT:
            c->content_buf += ev.text;
            break;
        case ds4cpp::ParserEvent::REASONING:
            c->reasoning_buf += ev.text;
            break;
        case ds4cpp::ParserEvent::TOOL_START: {
            // Grow the slot list on demand so this index is addressable.
            if (static_cast<int>(calls.size()) <= ev.index) {
                calls.resize(ev.index + 1);
            }
            auto &call = calls[ev.index];
            call.id = ev.tool_id;
            call.name = ev.tool_name;
            break;
        }
        case ds4cpp::ParserEvent::TOOL_ARGS:
            // Arguments for an index that has no slot yet are dropped.
            if (static_cast<int>(calls.size()) > ev.index) {
                calls[ev.index].args += ev.text;
            }
            break;
        case ds4cpp::ParserEvent::TOOL_END:
            // No-op for non-streaming: the final delta is emitted at the end.
            break;
        }
    }
}
|
|
|
|
static void collect_emit(void *ud, int token) {
|
|
auto *c = static_cast<CollectCtx *>(ud);
|
|
if (token == ds4_token_eos(c->engine)) return;
|
|
size_t len = 0;
|
|
const char *text = ds4_token_text(c->engine, token, &len);
|
|
if (!text || len == 0) return;
|
|
std::string chunk(text, len);
|
|
c->raw_buf += chunk;
|
|
std::vector<ds4cpp::ParserEvent> events;
|
|
c->parser.Feed(chunk, events);
|
|
apply_events(c, events);
|
|
c->tokens++;
|
|
}
|
|
// Completion hook for the non-streaming path; intentionally does nothing.
static void collect_done(void * /*ud*/) {
}
|
|
|
|
struct StreamCtx {
|
|
ds4_engine *engine;
|
|
ServerWriter<backend::Reply> *writer;
|
|
ds4cpp::DsmlParser parser;
|
|
int tokens;
|
|
bool aborted;
|
|
// Track which tool indices we've seen TOOL_START for, so subsequent
|
|
// ARGS deltas can elide the redundant id/name fields.
|
|
std::vector<bool> tool_started;
|
|
};
|
|
|
|
static void stream_emit(void *ud, int token) {
|
|
auto *s = static_cast<StreamCtx *>(ud);
|
|
if (s->aborted) return;
|
|
if (token == ds4_token_eos(s->engine)) return;
|
|
size_t len = 0;
|
|
const char *text = ds4_token_text(s->engine, token, &len);
|
|
if (!text || len == 0) return;
|
|
std::string chunk(text, len);
|
|
std::vector<ds4cpp::ParserEvent> events;
|
|
s->parser.Feed(chunk, events);
|
|
if (events.empty()) { s->tokens++; return; }
|
|
|
|
backend::Reply reply;
|
|
auto *delta = reply.add_chat_deltas();
|
|
bool any_field = false;
|
|
for (const auto &e : events) {
|
|
switch (e.type) {
|
|
case ds4cpp::ParserEvent::CONTENT:
|
|
delta->set_content(delta->content() + e.text);
|
|
any_field = true;
|
|
break;
|
|
case ds4cpp::ParserEvent::REASONING:
|
|
delta->set_reasoning_content(delta->reasoning_content() + e.text);
|
|
any_field = true;
|
|
break;
|
|
case ds4cpp::ParserEvent::TOOL_START: {
|
|
if ((int)s->tool_started.size() <= e.index)
|
|
s->tool_started.resize(e.index + 1, false);
|
|
s->tool_started[e.index] = true;
|
|
auto *tc = delta->add_tool_calls();
|
|
tc->set_index(e.index);
|
|
tc->set_id(e.tool_id);
|
|
tc->set_name(e.tool_name);
|
|
any_field = true;
|
|
break;
|
|
}
|
|
case ds4cpp::ParserEvent::TOOL_ARGS: {
|
|
auto *tc = delta->add_tool_calls();
|
|
tc->set_index(e.index);
|
|
tc->set_arguments(e.text);
|
|
any_field = true;
|
|
break;
|
|
}
|
|
case ds4cpp::ParserEvent::TOOL_END:
|
|
// No marker delta needed - the Go side closes the tool call on
|
|
// the final aggregator pass.
|
|
break;
|
|
}
|
|
}
|
|
reply.set_message(chunk);
|
|
reply.set_tokens(1);
|
|
if (any_field) {
|
|
if (!s->writer->Write(reply)) s->aborted = true;
|
|
}
|
|
s->tokens++;
|
|
}
|
|
// Completion hook for the streaming path; intentionally does nothing.
static void stream_done(void * /*ud*/) {
}
|
|
|
|
// Per-thread RNG state handed to ds4_session_sample. A zero state means
// "not seeded yet": it is then filled from the system_clock tick count (forced
// to 1 should that be zero). The returned pointer stays the same for a given
// thread, so ds4 owns the random walk after the first call.
static uint64_t *get_rng() {
    static thread_local uint64_t state = 0;
    if (state != 0) return &state;

    const auto ticks = std::chrono::system_clock::now().time_since_epoch().count();
    state = static_cast<uint64_t>(ticks);
    if (state == 0) state = 1;
    return &state;
}
|
|
|
|
// Sampling settings for one generation step. Plain aggregate, initialized
// positionally: temperature, top_k, top_p, min_p.
struct SampleParams {
    float temperature; // <= 0 makes the caller pick greedily (argmax)
    int top_k;
    float top_p;
    float min_p;
};
|
|
|
|
// Compute the effective sampling parameters for the next token, mirroring
|
|
// ds4_server.c:7102-7115:
|
|
// - thinking mode enabled -> override (T=1, top_k=0, top_p=1, min_p=0)
|
|
// - inside DSML structural position (tool-call markers) -> force T=0
|
|
// - otherwise -> the request's user-supplied sampling settings
|
|
// The parser argument carries state from tokens emitted so far; its
|
|
// IsInDsmlStructural() predicts the next token's classification.
|
|
static SampleParams compute_sample_params(const backend::PredictOptions *request,
|
|
const ds4cpp::DsmlParser &parser,
|
|
bool think_enabled);
|
|
|
|
// Derive the ds4 thinking mode from the request metadata. Per the vllm backend
// convention, "enable_thinking" gates thinking on/off and "reasoning_effort"
// picks the strength when it is on.
//   - "enable_thinking" present and not "true"/"1"  -> DS4_THINK_NONE
//   - "reasoning_effort" is "max" or "xhigh"        -> DS4_THINK_MAX
//   - anything else (thinking defaults to ON, per ds4-server) -> DS4_THINK_HIGH
static ds4_think_mode parse_think_mode(const backend::PredictOptions *request) {
    const auto &meta = request->metadata();

    const auto toggle = meta.find("enable_thinking");
    if (toggle != meta.end()) {
        const bool switched_on = toggle->second == "true" || toggle->second == "1";
        if (!switched_on) return DS4_THINK_NONE;
    }

    const auto effort = meta.find("reasoning_effort");
    if (effort == meta.end()) return DS4_THINK_HIGH;
    const bool maxed = effort->second == "max" || effort->second == "xhigh";
    return maxed ? DS4_THINK_MAX : DS4_THINK_HIGH;
}
|
|
|
|
// Effective sampling parameters for the next token. Starts from the request's
// own settings, replaces them wholesale when thinking is enabled, and finally
// forces greedy decoding while the parser sits in a DSML structural position.
static SampleParams compute_sample_params(const backend::PredictOptions *request,
                                          const ds4cpp::DsmlParser &parser,
                                          bool think_enabled) {
    SampleParams params = {
        request->temperature(),
        request->topk(),
        request->topp(),
        request->minp(),
    };

    // Match ds4-server: thinking mode wants creativity in the reasoning pass
    // and the trailing content, so the entire generation overrides sampling
    // unless DSML structural bytes take over below.
    if (think_enabled) {
        params = SampleParams{1.0f, 0, 1.0f, 0.0f};
    }

    // Tool-call structural bytes (tags, markers, headers) must parse cleanly.
    // Force greedy regardless of user/thinking settings.
    if (parser.IsInDsmlStructural()) {
        params.temperature = 0.0f;
    }
    return params;
}
|
|
|
|
// Build the rendered text for cache keying. We feed the same text the model
|
|
// will see; that lets the cache survive small client-side reformatting of
|
|
// chat history (the cache is keyed on bytes, not tokens).
|
|
static std::string render_prompt_text(const backend::PredictOptions *request) {
|
|
// Two-mode: either the raw prompt or the chat-template path. We mirror
|
|
// build_prompt's branching but accumulate text (not tokens) so we can
|
|
// SHA1 it for the cache key. ds4_session caches a tokens-indexed
|
|
// checkpoint, but the disk format keys on bytes per ds4-server's design.
|
|
if (!request->usetokenizertemplate() || request->messages_size() == 0) {
|
|
return request->prompt();
|
|
}
|
|
std::string out;
|
|
const std::string sys_role = "system";
|
|
for (const auto &m : request->messages()) {
|
|
if (m.role() == sys_role) { out += "[sys] " + m.content() + "\n"; break; }
|
|
}
|
|
for (const auto &m : request->messages()) {
|
|
if (m.role() == sys_role) continue;
|
|
out += "[" + m.role() + "] " + m.content() + "\n";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
// Process-wide KV cache handle. LoadModel points it at g_kv_cache_dir via
// SetDir(); maybe_load_cache / maybe_save_cache only use it when enabled().
ds4cpp::KvCache g_kv_cache;
|
|
|
|
// Try to recover prefill state for `rendered`. Returns the matched prefix length.
|
|
static size_t maybe_load_cache(const std::string &rendered) {
|
|
if (!g_kv_cache.enabled() || !g_session) return 0;
|
|
return g_kv_cache.LoadLongestPrefix(g_session, rendered, g_ctx_size);
|
|
}
|
|
|
|
static void maybe_save_cache(const std::string &rendered) {
|
|
if (g_kv_cache.enabled() && g_session) {
|
|
g_kv_cache.Save(g_session, rendered, g_ctx_size);
|
|
}
|
|
}
|
|
|
|
// Tokenize the request into `out`.
//
// Raw mode (no tokenizer template, or no messages): the prompt string is
// tokenized as-is. Chat mode: the conversation is rendered through ds4's chat
// helpers - one system message (first system message plus the tools manifest,
// when present), then every non-system message in order, then the assistant
// prefix for the requested thinking mode.
static void build_prompt(ds4_engine *engine, const backend::PredictOptions *request,
                         ds4_tokens *out) {
    const bool chat_mode = request->usetokenizertemplate() && request->messages_size() != 0;
    if (!chat_mode) {
        ds4_tokenize_text(engine, request->prompt().c_str(), out);
        return;
    }

    // Chat-template path. ds4_encode_chat_prompt is convenient when there is
    // exactly one system+user pair, but for arbitrary turn lists the granular
    // append helpers are used instead.
    ds4_chat_begin(engine, out);

    const std::string kSystemRole = "system";

    // Only the first system message is used.
    std::string system_block;
    for (const auto &msg : request->messages()) {
        if (msg.role() == kSystemRole) {
            system_block = msg.content();
            break;
        }
    }

    // Inject the tools manifest into the system prompt when tools are present.
    // ds4 was trained to emit DSML tool calls ONLY when this preamble is in
    // the system message - without it, the model has no idea tools exist and
    // the e2e tool-call test will fail. The renderer lives in dsml_renderer
    // and is a verbatim port of ds4_server.c's append_tools_prompt_text.
    if (!request->tools().empty()) {
        const std::string manifest = ds4cpp::RenderToolsManifest(request->tools());
        if (!manifest.empty()) {
            if (!system_block.empty()) system_block += "\n\n";
            system_block += manifest;
        }
    }
    if (!system_block.empty()) {
        ds4_chat_append_message(engine, out, "system", system_block.c_str());
    }

    for (const auto &msg : request->messages()) {
        const std::string &role = msg.role();
        if (role == kSystemRole) continue;

        if (role == "assistant" && !msg.tool_calls().empty()) {
            // Assistant turn that issued tool calls: content followed by the
            // rendered calls.
            std::string text = msg.content();
            text += ds4cpp::RenderAssistantToolCalls(msg.tool_calls());
            ds4_chat_append_message(engine, out, "assistant", text.c_str());
            continue;
        }
        if (role == "tool") {
            // Tool results are appended under the "user" role.
            const std::string result = ds4cpp::RenderToolResult(msg.tool_call_id(), msg.content());
            ds4_chat_append_message(engine, out, "user", result.c_str());
            continue;
        }
        ds4_chat_append_message(engine, out, role.c_str(), msg.content().c_str());
    }

    ds4_chat_append_assistant_prefix(engine, out, parse_think_mode(request));
}
|
|
|
|
class DS4Backend final : public backend::Backend::Service {
|
|
public:
|
|
// Liveness probe: always replies "OK". Takes no lock and reads no engine state.
GStatus Health(ServerContext * /*context*/, const backend::HealthMessage * /*request*/,
               backend::Reply *reply) override {
    const std::string ok_text("OK");
    reply->set_message(ok_text);
    return GStatus::OK;
}
|
|
|
|
// Unload the model: frees the session first, then closes the engine, under
// g_engine_mu. Always reports success, including when nothing was loaded.
GStatus Free(ServerContext * /*context*/, const backend::HealthMessage * /*request*/,
             backend::Result *result) override {
    const std::lock_guard<std::mutex> guard(g_engine_mu);
    if (g_session != nullptr) {
        ds4_session_free(g_session);
        g_session = nullptr;
    }
    if (g_engine != nullptr) {
        ds4_engine_close(g_engine);
        g_engine = nullptr;
    }
    result->set_success(true);
    return GStatus::OK;
}
|
|
|
|
// Load (or swap to) a model.
//
// Closes any engine/session already open, translates ModelOptions into
// ds4_engine_options, opens the engine and creates the session. Recognized
// ModelOptions.Options entries ("key:value", or a bare "key"):
//   kv_cache_dir, ds4_role, ds4_layers, ds4_listen, ds4_route_timeout, plus
//   every key in kEngineOptSpecs; unknown keys are ignored.
//
// Failures are reported in-band: the gRPC status is always OK and
// result.success / result.message carry the outcome.
GStatus LoadModel(ServerContext *, const backend::ModelOptions *request,
                  backend::Result *result) override {
    std::lock_guard<std::mutex> lock(g_engine_mu);

    // Shared failure exit for every validation/open error below.
    auto fail = [result](const std::string &message) {
        result->set_success(false);
        result->set_message(message);
        return GStatus::OK;
    };

    // Reset per-model state so a model swap (a second LoadModel without
    // ds4_role / kv_cache_dir) doesn't inherit a stale coordinator
    // configuration or the previous model's KV cache directory.
    g_distributed = false;
    g_route_timeout_sec = 60;
    g_kv_cache_dir.clear();

    if (g_engine) {
        if (g_session) {
            ds4_session_free(g_session);
            g_session = nullptr;
        }
        ds4_engine_close(g_engine);
        g_engine = nullptr;
    }

    std::string model_path = request->modelfile();
    if (model_path.empty()) model_path = request->model();
    if (model_path.empty()) {
        return fail("ds4: ModelOptions.Model or .ModelFile must be set");
    }

    ds4_engine_options opt = {};
    opt.model_path = model_path.c_str();
    opt.n_threads = request->threads() > 0 ? request->threads() : 0;
    opt.mtp_margin = 3.0f; // ds4 default; overridable via the mtp_margin option

#if defined(DS4_NO_GPU)
    opt.backend = DS4_BACKEND_CPU;
#elif defined(__APPLE__)
    opt.backend = DS4_BACKEND_METAL;
#else
    opt.backend = DS4_BACKEND_CUDA;
#endif

    // Stable storage for string-valued engine options. The engine reads
    // these by pointer during ds4_engine_open, so the std::string backing
    // store must outlive the call and not reallocate. Each option adds at
    // most one string, so reserving one slot per option passed (and at least
    // one per table row) keeps every prior c_str() valid even when a key is
    // repeated. Static + clear() reuses the buffer across LoadModel calls
    // (the old engine is closed above).
    static std::vector<std::string> s_opt_strings;
    s_opt_strings.clear();
    const size_t n_spec_rows = sizeof(kEngineOptSpecs) / sizeof(kEngineOptSpecs[0]);
    const size_t n_options = static_cast<size_t>(request->options_size());
    s_opt_strings.reserve(n_options > n_spec_rows ? n_options : n_spec_rows);

    // Directory of the main model, used to resolve relative path options.
    std::string model_dir;
    if (auto slash = model_path.find_last_of('/'); slash != std::string::npos) {
        model_dir = model_path.substr(0, slash);
    }

    std::string ds4_role, ds4_layers, ds4_listen;
    for (const auto &o : request->options()) {
        auto [k, v] = split_option(o);
        if (k == "kv_cache_dir") { g_kv_cache_dir = v; continue; }
        if (k == "ds4_role") { ds4_role = v; continue; }
        if (k == "ds4_layers") { ds4_layers = v; continue; }
        if (k == "ds4_listen") { ds4_listen = v; continue; }
        if (k == "ds4_route_timeout") {
            if (!parse_positive_int(v, &g_route_timeout_sec)) {
                return fail("ds4: ds4_route_timeout must be a positive integer");
            }
            continue;
        }
        // Everything else is looked up in the engine option table.
        std::string err;
        if (!apply_engine_option(&opt, k, v, model_dir, s_opt_strings, err)) {
            return fail("ds4: " + err);
        }
    }

    g_kv_cache.SetDir(g_kv_cache_dir);

    // Coordinator wiring. 'ds4_role:coordinator' enables layer-split
    // distributed inference: this process listens on ds4_listen and owns
    // the ds4_layers slice; workers dial in (see `local-ai worker
    // ds4-distributed`). Absent ds4_role => unchanged single-node path.
    // Must be static: opt.distributed.listen_host is a const char* the
    // engine retains past this call, so it cannot point at a local that
    // goes out of scope (otherwise a future "simplify to local" refactor
    // reintroduces a dangling pointer).
    static std::string s_listen_host;
    if (ds4_role == "coordinator") {
        if (ds4_layers.empty() || ds4_listen.empty()) {
            return fail("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
        }
        // host:port for IPv4/hostname; IPv6 literals are unsupported (the
        // first colon would split inside the address).
        auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
        if (host_port.second.empty()) {
            return fail("ds4: ds4_listen must be host:port");
        }
        int listen_port = 0;
        if (!parse_positive_int(host_port.second, &listen_port)) {
            return fail("ds4: ds4_listen port must be a positive integer");
        }
        // TCP ports are 16-bit; reject values no socket can bind.
        if (listen_port > 65535) {
            return fail("ds4: ds4_listen port must be in the range 1-65535");
        }
        ds4_distributed_layers layers = {};
        if (!parse_layers_spec(ds4_layers, &layers)) {
            return fail("ds4: invalid ds4_layers (want START:END or START:output)");
        }
        s_listen_host = host_port.first;
        opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
        opt.distributed.layers = layers;
        opt.distributed.listen_host = s_listen_host.c_str();
        opt.distributed.listen_port = listen_port;
        g_distributed = true;
    }

    int rc = ds4_engine_open(&g_engine, &opt);
    if (rc != 0 || !g_engine) {
        return fail("ds4_engine_open failed (rc=" + std::to_string(rc) + ")");
    }

    g_ctx_size = request->contextsize() > 0 ? request->contextsize() : 32768;
    rc = ds4_session_create(&g_session, g_engine, g_ctx_size);
    if (rc != 0 || !g_session) {
        // Don't leave a half-loaded engine behind.
        ds4_engine_close(g_engine);
        g_engine = nullptr;
        return fail("ds4_session_create failed (rc=" + std::to_string(rc) + ")");
    }

    result->set_success(true);
    result->set_message("loaded " + model_path);
    return GStatus::OK;
}
|
|
|
|
// Tokenize request.prompt with the loaded engine and return the token ids and
// their count. Fails with FAILED_PRECONDITION when no model is loaded.
GStatus TokenizeString(ServerContext * /*context*/, const backend::PredictOptions *request,
                       backend::TokenizationResponse *response) override {
    const std::lock_guard<std::mutex> guard(g_engine_mu);
    if (g_engine == nullptr) {
        return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
    }

    ds4_tokens toks = {};
    ds4_tokenize_text(g_engine, request->prompt().c_str(), &toks);
    for (int idx = 0; idx < toks.len; ++idx) {
        response->add_tokens(toks.v[idx]);
    }
    response->set_length(toks.len);
    ds4_tokens_free(&toks); // the token buffer is released here
    return GStatus::OK;
}
|
|
|
|
// Non-streaming completion. While holding g_engine_mu it builds the prompt
// tokens, syncs g_session to them, runs the generation loop, and returns the
// whole result in *reply: one ChatDelta (content / reasoning / tool calls)
// plus the raw message and the token counters.
//
// Status codes:
//   FAILED_PRECONDITION - no engine or no session is loaded.
//   UNAVAILABLE         - wait_route_ready() returned a non-empty error text.
//   INTERNAL            - session sync, eval, or speculative eval failed; the
//                         message carries whatever text was left in `err`.
GStatus Predict(ServerContext *, const backend::PredictOptions *request,
                backend::Reply *reply) override {
    std::unique_lock<std::mutex> engine_lock(g_engine_mu);
    if (!(g_engine && g_session)) {
        return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
    }
    std::string route_err = wait_route_ready(engine_lock);
    if (!route_err.empty()) {
        return GStatus(StatusCode::UNAVAILABLE, route_err);
    }

    ds4_tokens prompt = {};
    build_prompt(g_engine, request, &prompt);

    // A non-positive token budget in the request falls back to 256.
    int n_predict = 256;
    if (request->tokens() > 0) {
        n_predict = request->tokens();
    }

    CollectCtx acc = {g_engine, "", {}, reply, 0, {}, "", ""};
    std::string cache_key = render_prompt_text(request);
    size_t cache_hit = maybe_load_cache(cache_key);
    (void)cache_hit;  // future: skip prompt prefix if hit covers full prompt

    // The loop below drives g_session by hand. With MTP draft tokens
    // configured (LoadModel option 'mtp_path:') and greedy sampling, it goes
    // through ds4_session_eval_speculative_argmax, which can hand back more
    // than one token per outer iteration; otherwise it is one pick + one
    // eval per token. Both routes advance g_session so that the
    // maybe_save_cache() call afterwards has a real checkpoint to store.
    char err[256] = {0};
    int rc = ds4_session_sync(g_session, &prompt, err, sizeof(err));
    int prompt_len = prompt.len;
    ds4_tokens_free(&prompt);

    if (rc == 0) {
        const int eos = ds4_token_eos(g_engine);
        const int draft_max = ds4_engine_mtp_draft_tokens(g_engine);
        const bool think_enabled = ds4_think_mode_enabled(parse_think_mode(request));

        // Output buffer for the speculative path; `cap` never exceeds it.
        constexpr int kAcceptedMax = 8;
        int accepted[kAcceptedMax];
        const int cap = std::min(kAcceptedMax, draft_max + 1);

        for (int emitted = 0; emitted < n_predict;) {
            SampleParams params = compute_sample_params(request, acc.parser, think_enabled);
            const bool greedy = params.temperature <= 0.0f;
            const int tok = greedy
                ? ds4_session_argmax(g_session)
                : ds4_session_sample(g_session,
                                     params.temperature, params.top_k,
                                     params.top_p, params.min_p, get_rng());
            if (tok == eos) break;

            // MTP only when sampling is greedy (ds4-server gate).
            if (!(draft_max > 0 && greedy)) {
                collect_emit(&acc, tok);
                if (++emitted >= n_predict) break;
                rc = ds4_session_eval(g_session, tok, err, sizeof(err));
                if (rc != 0) break;
                continue;
            }

            const int n_accepted = ds4_session_eval_speculative_argmax(
                g_session, tok, draft_max, eos,
                accepted, cap, err, sizeof(err));
            if (n_accepted < 0) {
                rc = -1;
                break;
            }

            // Emit accepted tokens until EOS shows up or the budget is spent.
            bool finished = false;
            for (int j = 0; j < n_accepted && !finished; ++j) {
                if (accepted[j] == eos) {
                    finished = true;
                } else {
                    collect_emit(&acc, accepted[j]);
                    finished = (++emitted >= n_predict);
                }
            }
            if (finished) break;
        }
        collect_done(&acc);
    }
    maybe_save_cache(cache_key);

    // Drain whatever the parser is still holding back.
    std::vector<ds4cpp::ParserEvent> tail_events;
    acc.parser.Flush(tail_events);
    apply_events(&acc, tail_events);

    if (rc != 0) {
        return GStatus(StatusCode::INTERNAL,
                       std::string("ds4 generation failed: ") + err);
    }

    // Emit one ChatDelta with content/reasoning/tool_calls.
    auto *delta = reply->add_chat_deltas();
    delta->set_content(acc.content_buf);
    delta->set_reasoning_content(acc.reasoning_buf);
    for (size_t i = 0; i < acc.pending.size(); ++i) {
        const auto &call = acc.pending[i];
        auto *tc = delta->add_tool_calls();
        tc->set_index(static_cast<int32_t>(i));
        tc->set_id(call.id);
        tc->set_name(call.name);
        tc->set_arguments(call.args);
    }

    reply->set_message(acc.raw_buf);
    reply->set_tokens(acc.tokens);
    reply->set_prompt_tokens(prompt_len);
    return GStatus::OK;
}
|
|
|
|
// Streaming completion. Runs the same hand-driven generation loop as
// Predict(), but hands each token to stream_emit() and stops as soon as
// s.aborted becomes true.
// NOTE(review): s.aborted is presumably raised by stream_emit when a write to
// the client fails - confirm against stream_emit.
//
// Status codes:
//   FAILED_PRECONDITION - no engine or no session is loaded.
//   UNAVAILABLE         - wait_route_ready() returned a non-empty error text.
//   INTERNAL            - generation failed and the stream was not aborted.
//   OK                  - everything else, including an aborted stream.
GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
                      ServerWriter<backend::Reply> *writer) override {
    std::unique_lock<std::mutex> engine_lock(g_engine_mu);
    if (!(g_engine && g_session)) {
        return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
    }
    std::string route_err = wait_route_ready(engine_lock);
    if (!route_err.empty()) {
        return GStatus(StatusCode::UNAVAILABLE, route_err);
    }

    ds4_tokens prompt = {};
    build_prompt(g_engine, request, &prompt);

    // A non-positive token budget in the request falls back to 256.
    int n_predict = 256;
    if (request->tokens() > 0) {
        n_predict = request->tokens();
    }

    StreamCtx s = {g_engine, writer, {}, 0, false, {}};
    std::string cache_key = render_prompt_text(request);
    size_t cache_hit = maybe_load_cache(cache_key);
    (void)cache_hit;

    // Hand-driven loop on g_session; the speculative (MTP) route is taken
    // when ds4_engine_mtp_draft_tokens() > 0 and sampling is greedy.
    char err[256] = {0};
    int rc = ds4_session_sync(g_session, &prompt, err, sizeof(err));
    ds4_tokens_free(&prompt);

    if (rc == 0) {
        const int eos = ds4_token_eos(g_engine);
        const int draft_max = ds4_engine_mtp_draft_tokens(g_engine);
        const bool think_enabled = ds4_think_mode_enabled(parse_think_mode(request));

        // Output buffer for the speculative path; `cap` never exceeds it.
        constexpr int kAcceptedMax = 8;
        int accepted[kAcceptedMax];
        const int cap = std::min(kAcceptedMax, draft_max + 1);

        for (int emitted = 0; emitted < n_predict && !s.aborted;) {
            SampleParams params = compute_sample_params(request, s.parser, think_enabled);
            const bool greedy = params.temperature <= 0.0f;
            const int tok = greedy
                ? ds4_session_argmax(g_session)
                : ds4_session_sample(g_session,
                                     params.temperature, params.top_k,
                                     params.top_p, params.min_p, get_rng());
            if (tok == eos) break;

            if (!(draft_max > 0 && greedy)) {
                // One token per iteration: emit, then feed it back.
                stream_emit(&s, tok);
                if (s.aborted || ++emitted >= n_predict) break;
                rc = ds4_session_eval(g_session, tok, err, sizeof(err));
                if (rc != 0) break;
                continue;
            }

            const int n_accepted = ds4_session_eval_speculative_argmax(
                g_session, tok, draft_max, eos,
                accepted, cap, err, sizeof(err));
            if (n_accepted < 0) {
                rc = -1;
                break;
            }

            // Emit accepted tokens until EOS, an abort, or the budget ends it.
            bool finished = false;
            for (int j = 0; j < n_accepted && !finished; ++j) {
                if (accepted[j] == eos) {
                    finished = true;
                    continue;
                }
                stream_emit(&s, accepted[j]);
                finished = s.aborted || (++emitted >= n_predict);
            }
            if (finished) break;
        }
        stream_done(&s);
    }
    maybe_save_cache(cache_key);

    // Drain the parser; only CONTENT and REASONING events are forwarded in
    // this final reply, and nothing is written once the stream is aborted.
    std::vector<ds4cpp::ParserEvent> tail_events;
    s.parser.Flush(tail_events);
    const bool send_tail = !tail_events.empty() && !s.aborted;
    if (send_tail) {
        backend::Reply reply;
        auto *delta = reply.add_chat_deltas();
        for (const auto &ev : tail_events) {
            if (ev.type == ds4cpp::ParserEvent::CONTENT) {
                delta->set_content(delta->content() + ev.text);
            } else if (ev.type == ds4cpp::ParserEvent::REASONING) {
                delta->set_reasoning_content(delta->reasoning_content() + ev.text);
            }
        }
        s.writer->Write(reply);
    }

    const bool failed = rc != 0;
    if (failed && !s.aborted) {
        return GStatus(StatusCode::INTERNAL,
                       std::string("ds4 generation failed: ") + err);
    }
    return GStatus::OK;
}
|
|
|
|
// Reports backend readiness: READY when an engine is loaded, UNINITIALIZED
// otherwise. Takes g_engine_mu while reading g_engine. Always returns OK.
GStatus Status(ServerContext *, const backend::HealthMessage *,
               backend::StatusResponse *response) override {
    std::lock_guard<std::mutex> guard(g_engine_mu);
    if (g_engine) {
        response->set_state(backend::StatusResponse::READY);
    } else {
        response->set_state(backend::StatusResponse::UNINITIALIZED);
    }
    return GStatus::OK;
}
|
|
};
|
|
|
|
void RunServer(const std::string &addr) {
|
|
DS4Backend service;
|
|
grpc::EnableDefaultHealthCheckService(true);
|
|
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
|
|
|
ServerBuilder builder;
|
|
builder.AddListeningPort(addr, grpc::InsecureServerCredentials());
|
|
builder.RegisterService(&service);
|
|
builder.SetMaxReceiveMessageSize(64 * 1024 * 1024);
|
|
builder.SetMaxSendMessageSize(64 * 1024 * 1024);
|
|
|
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
|
if (!server) {
|
|
std::cerr << "ds4 grpc-server: failed to bind " << addr << "\n";
|
|
std::exit(1);
|
|
}
|
|
g_server = server.get();
|
|
std::cerr << "ds4 grpc-server listening on " << addr << "\n";
|
|
server->Wait();
|
|
}
|
|
|
|
void signal_handler(int) {
|
|
if (auto *srv = g_server.load()) {
|
|
srv->Shutdown(std::chrono::system_clock::now() +
|
|
std::chrono::seconds(3));
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Entry point. Accepts `--addr=HOST:PORT` or `--addr HOST:PORT` (default
// 127.0.0.1:50051); `--help` / `-h` prints usage and exits with 0. Any other
// argument, including a trailing `--addr` with no value, is ignored.
// Installs signal_handler for SIGINT and SIGTERM, then runs the server.
int main(int argc, char *argv[]) {
    const std::string addr_prefix = "--addr=";
    std::string addr = "127.0.0.1:50051";

    int i = 1;
    while (i < argc) {
        const std::string arg = argv[i];
        if (arg.compare(0, addr_prefix.size(), addr_prefix) == 0) {
            // `--addr=VALUE`: everything after the '=' is the address.
            addr = arg.substr(addr_prefix.size());
        } else if (arg == "--addr") {
            // `--addr VALUE`: consume the next argument when there is one.
            if (i + 1 < argc) {
                addr = argv[++i];
            }
        } else if (arg == "--help" || arg == "-h") {
            std::cout << "Usage: grpc-server --addr=HOST:PORT\n";
            return 0;
        }
        ++i;
    }

    std::signal(SIGINT, signal_handler);
    std::signal(SIGTERM, signal_handler);
    RunServer(addr);
    return 0;
}
|