From 99184809fa5a310ceba9aa6b036caa806838ad75 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jun 2026 16:14:37 +0000 Subject: [PATCH] feat(dllm): rich gRPC backend with ChatDelta streaming Implements PredictRich/PredictStreamRich (legacy methods delegate), TokenizeString, and Load over the purego binding. A single worker goroutine serializes all C calls per the dllm.cpp one-generate-per-ctx contract (cancel is the documented exception); an RWMutex guards Free against in-flight requests. Under use_tokenizer_template the gemma4 renderer and streaming parser own templating and ChatDelta extraction; raw-prompt mode passes through verbatim. enable_thinking is opt-in via request metadata (the gemma4 template treats thinking as opt-in). Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto --- backend/go/dllm/capi.go | 35 +- backend/go/dllm/dllm.go | 536 ++++++++++++++++++++++++++++++ backend/go/dllm/dllm_test.go | 627 +++++++++++++++++++++++++++++++++++ backend/go/dllm/main.go | 0 4 files changed, 1176 insertions(+), 22 deletions(-) mode change 100644 => 100755 backend/go/dllm/capi.go create mode 100755 backend/go/dllm/dllm.go mode change 100644 => 100755 backend/go/dllm/dllm_test.go mode change 100644 => 100755 backend/go/dllm/main.go diff --git a/backend/go/dllm/capi.go b/backend/go/dllm/capi.go old mode 100644 new mode 100755 index d8c0ca11e..088bb6f26 --- a/backend/go/dllm/capi.go +++ b/backend/go/dllm/capi.go @@ -16,15 +16,12 @@ package main import ( "encoding/json" - "errors" "fmt" "sync" "sync/atomic" "unsafe" "github.com/ebitengine/purego" - "github.com/mudler/LocalAI/pkg/grpc/base" - pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) // dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written @@ -48,18 +45,6 @@ var ( cppCancel func(ctx uintptr) ) -// Dllm is the LocalAI gRPC backend over the dllm.cpp C-ABI. T1 ships only -// the binding scaffold; Load/PredictRich/PredictStreamRich (and the move to -// a dedicated dllm.go with the per-model worker goroutine) land in T4. -type Dllm struct { - base.Base -} - -// Load is not wired yet: the binding smoke drives the C functions directly. -func (d *Dllm) Load(opts *pb.ModelOptions) error { - return errors.New("dllm: model loading not implemented yet (backend wiring lands in T4)") -} - // cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION. func cAbiVersion() int32 { return cppAbiVersion() @@ -218,6 +203,11 @@ func cCancel(h uintptr) { // nested objects/arrays loudly; bools are rejected here too because the // scanner has no concept of them. Fail loud rather than let an option be // silently misread. +// +// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g. +// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys +// (kv_cache: auto|on|off) can contain those bytes today; if one ever does, +// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString. func buildOptsJSON(opts map[string]any) (string, error) { if len(opts) == 0 { return "{}", nil @@ -246,17 +236,18 @@ func buildOptsJSON(opts map[string]any) (string, error) { // caller owns, or a callback argument only valid during the invocation); // owning callers must free it via cppFreeString after the copy lands. // -// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr -// check, which can't distinguish a C-owned heap pointer from Go-managed -// memory. It is safe here: the pointer addresses C memory the Go GC neither -// tracks nor moves, and we dereference it immediately to copy the bytes out, -// the same pattern (and the same tolerated warning) as the parakeet-cpp and -// whisper backends. +// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check, +// which can't distinguish a C-owned heap pointer from Go-managed memory (the +// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting +// through &cptr below is equivalent at runtime and keeps plain `go vet` +// clean. It is safe either way: the pointer addresses C memory the Go GC +// neither tracks nor moves, and we dereference it immediately to copy the +// bytes out. func goStringFromCPtr(cptr uintptr) string { if cptr == 0 { return "" } - p := unsafe.Pointer(cptr) //nolint:govet // C-owned buffer, not Go-GC memory (see doc above) + p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above) n := 0 for *(*byte)(unsafe.Add(p, n)) != 0 { n++ diff --git a/backend/go/dllm/dllm.go b/backend/go/dllm/dllm.go new file mode 100755 index 000000000..cd82ff0b3 --- /dev/null +++ b/backend/go/dllm/dllm.go @@ -0,0 +1,536 @@ +package main + +// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models). +// +// Wiring overview: +// - Load opens the GGUF via dllm_capi_load and starts the per-model worker +// goroutine that serializes every C call (see submit). +// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the +// request carries raw messages (use_tokenizer_template), the backend owns +// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies +// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend. +// - The legacy Predict / PredictStream methods delegate to the rich pair +// (cloud-proxy precedent); the gRPC server prefers the rich path anyway. + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "unicode/utf8" + + "github.com/mudler/LocalAI/pkg/grpc/base" + "github.com/mudler/LocalAI/pkg/grpc/grpcerrors" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" +) + +// generator is the seam between the backend wiring and the dllm.cpp C-ABI: +// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON +// family, while tests substitute a fake to exercise prompt construction, +// parsing and serialization without libdllm.so. +type generator interface { + generate(prompt, optsJSON string) (string, error) + // generateStream invokes onBlock once per committed diffusion block, on + // the thread running the C call, before returning. + generateStream(prompt, optsJSON string, onBlock func(text string)) error + tokenizeJSON(text string) (string, error) + // cancel is the ONE entry point safe to call concurrently with an + // in-flight generate on the same ctx (dllm_capi.h: it only flips an + // atomic; everything else must be externally serialized per ctx). + cancel() + free() +} + +// capiGenerator is the production generator over one dllm_ctx handle. +type capiGenerator struct { + h uintptr +} + +func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) { + return cGenerate(g.h, prompt, optsJSON) +} + +func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error { + // on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is + // passed as nil for now: a future progress hook for the React UI can + // plumb it through without touching the C binding. + return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil) +} + +func (g *capiGenerator) tokenizeJSON(text string) (string, error) { + return cTokenizeJSON(g.h, text) +} + +func (g *capiGenerator) cancel() { + cCancel(g.h) +} + +func (g *capiGenerator) free() { + cFree(g.h) +} + +// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts +// one backend process per model). +type Dllm struct { + base.Base + + gen generator + // genOpts holds the model-level generation overrides parsed from + // ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes + // them per-generate, not per-load, so they are merged into every + // request's opts JSON (requestOptsJSON). + genOpts map[string]any + + // jobs is the per-model worker queue. dllm_capi.h requires every entry + // point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one + // ctx = one concurrent generate/tokenize; last_error is unsafe to read + // while a call is in flight). A single goroutine owning all C calls makes + // that contract structural instead of relying on lock discipline. + jobs chan func() + workerWG sync.WaitGroup + + // genMu guards gen against Free racing in-flight requests: requests hold + // the read lock for their full duration (they stay concurrent with each + // other - the worker still serializes the C calls), Free takes the write + // lock so it can only run when no request is in flight. + genMu sync.RWMutex +} + +func (d *Dllm) startWorker() { + d.jobs = make(chan func()) + d.workerWG.Add(1) + go func() { + defer d.workerWG.Done() + for job := range d.jobs { + job() + } + }() +} + +// submit runs job on the worker goroutine and waits for it to finish. +// Concurrent gRPC requests therefore queue up and execute one at a time +// against the single dllm_ctx. +func (d *Dllm) submit(job func()) { + done := make(chan struct{}) + d.jobs <- func() { + defer close(done) + job() + } + <-done +} + +// Load opens the GGUF and prepares the worker. Load-time engine parameters +// travel as the flat params JSON of dllm_capi_load; generation overrides +// from Options are stored for per-request opts JSON instead (the C-ABI has +// no per-load sampler state). +func (d *Dllm) Load(opts *pb.ModelOptions) error { + if d.gen != nil { + return errors.New("dllm: model already loaded") + } + + params := map[string]any{ + "n_gpu_layers": opts.GetNGPULayers(), + } + if opts.GetThreads() > 0 { + params["n_threads"] = opts.GetThreads() + } + if opts.GetContextSize() > 0 { + params["ctx_len"] = opts.GetContextSize() + } + paramsJSON, err := buildOptsJSON(params) + if err != nil { + return err + } + + d.genOpts = parseModelGenOpts(opts.GetOptions()) + + h := cLoad(opts.GetModelFile(), paramsJSON) + if h == 0 { + // No ctx exists on load failure, so last_error(NULL) only carries the + // static NULL-ctx message; the real reason is on the backend's stderr. + return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)", + opts.GetModelFile(), lastErrorOr(0, "unknown error")) + } + d.gen = &capiGenerator{h: h} + d.startWorker() + xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts) + return nil +} + +// Free releases the dllm ctx and stops the worker. Safe when never loaded. +// +// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the +// model-unload path around line 764) calls Free with no locking of its own, +// and base.Base provides none either. Without it a request racing Free would +// panic sending on the closed jobs channel - or worse, generate on a freed C +// ctx. Holding genMu until gen is nil also turns post-Free requests into a +// clean "model not loaded" error instead of a crash. +func (d *Dllm) Free() error { + d.genMu.Lock() + defer d.genMu.Unlock() + if d.gen == nil { + return nil + } + d.submit(d.gen.free) + close(d.jobs) + d.workerWG.Wait() + d.gen = nil + return nil +} + +// Cancel requests cancellation of the in-flight generate. It deliberately +// bypasses the worker queue: dllm_capi_cancel is the one call the C-ABI +// allows from any goroutine mid-generate (it only flips an atomic). +// +// LIMITATION: nothing invokes this on client disconnect today. The gRPC +// server (pkg/grpc/server.go) does not hand the request/stream context to +// Predict/PredictStreamRich, so a dropped HTTP client cannot reach the +// backend until that plumbing exists; the method is here so future server +// wiring (or an admin RPC) has something to call. Note dllm_capi.h's +// cancel-reset race: each generate resets the flag on entry, so a caller +// racing a new generate should re-issue Cancel. +func (d *Dllm) Cancel() { + if d.gen != nil { + d.gen.cancel() + } +} + +// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to +// the engine. Options is a shared free-form bag (other layers put their own +// entries there), so unknown keys are skipped with a warning, not an error. +var dllmGenOptKeys = map[string]bool{ + "blocks": true, + "kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3 +} + +// parseModelGenOpts parses "key:value" Options entries into the flat scalar +// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler +// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by +// first successful parse (int, then float, else string) to match the C +// scanner's number/string scalars. +func parseModelGenOpts(options []string) map[string]any { + out := map[string]any{} + for _, o := range options { + key, val, found := strings.Cut(o, ":") + if !found { + xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o) + continue + } + if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] { + xlog.Debug("dllm: ignoring unrecognized option", "key", key) + continue + } + out[key] = parseScalarOpt(val) + } + return out +} + +func parseScalarOpt(v string) any { + if iv, err := strconv.ParseInt(v, 10, 64); err == nil { + return iv + } + if fv, err := strconv.ParseFloat(v, 64); err == nil { + return fv + } + return v +} + +// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default +// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat +// template guards every thinking branch with `enable_thinking is defined and +// enable_thinking`, i.e. thinking is opt-in for this model family, and the +// no-thinking render pre-closes an empty thought channel that the OFF +// default must produce. +func metadataEnableThinking(opts *pb.PredictOptions) bool { + v := opts.GetMetadata()["enable_thinking"] + return v == "true" || v == "1" +} + +// buildPrompt resolves the prompt for a request. With use_tokenizer_template +// and raw messages the backend owns templating (RenderGemma4) and the output +// is in the known gemma4 format, so parse=true. Without it the caller +// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or +// a bare completion): the prompt passes through verbatim and the output is +// NOT gemma4-parsed - it is emitted as plain content and the Go side's +// extraction applies, as for any non-autoparsing backend. +func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) { + if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 { + prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true) + return prompt, true, err + } + return opts.GetPrompt(), false, nil +} + +// requestOptsJSON merges the model-level overrides with the request's +// sampling fields into the flat opts JSON for one generate call. +func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) { + m := make(map[string]any, len(d.genOpts)+2) + for k, v := range d.genOpts { + m[k] = v + } + if n := opts.GetTokens(); n > 0 { + // The engine rounds n_predict UP to a whole number of diffusion + // blocks (the canvas is denoised block-wise), so the completion may + // run slightly past the requested budget. Tokens==0 omits the key so + // the engine's GGUF-metadata default applies (the C-ABI documents + // per-key defaults; no hardcoded 256 like ds4's grpc-server). + m["n_predict"] = n + } + if s := opts.GetSeed(); s > 0 { + // The engine seeds mt19937 with explicit non-negative seeds. Seed<=0 + // is omitted: proto3 cannot distinguish 0 from unset, and negative + // values conventionally mean "random" across LocalAI backends. + m["seed"] = s + } + return buildOptsJSON(m) +} + +// prepareRequest is the shared prologue of the rich methods: resolve the +// prompt (and whether the output gets gemma4-parsed) and build the per-call +// opts JSON. +func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) { + prompt, parse, err = buildPrompt(opts) + if err != nil { + return "", false, "", err + } + optsJSON, err = d.requestOptsJSON(opts) + if err != nil { + return "", false, "", err + } + return prompt, parse, optsJSON, nil +} + +// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary +// detokenization and byte-fallback tokens can produce invalid UTF-8, and +// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so +// every string destined for a Reply/ChatDelta must pass through here (or +// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely +// undecodable: replace with U+FFFD rather than crash the stream. +func sanitizeUTF8(s string) string { + if utf8.ValidString(s) { + return s + } + return strings.ToValidUTF8(s, "�") +} + +// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte +// (1 for bytes that can never lead a multi-byte sequence, so they are never +// held back and fall through to sanitizeUTF8's replacement). +func utf8SeqLen(b byte) int { + switch { + case b&0xE0 == 0xC0: + return 2 + case b&0xF0 == 0xE0: + return 3 + case b&0xF8 == 0xF0: + return 4 + default: + return 1 + } +} + +// splitValidUTF8 prepends the previous block's carry to the new block and +// splits the result into text safe to emit now and a trailing INCOMPLETE +// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block: +// the per-block detokenize can split a multi-byte character across block +// boundaries (llama.cpp's grpc-server holds back the same way). Only a +// suffix that can still become a valid rune is withheld; bytes that are +// already undecodable are replaced immediately so the carry stays bounded. +func splitValidUTF8(carry, block string) (emit, newCarry string) { + s := carry + block + cut := len(s) + for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- { + b := s[i] + if b < utf8.RuneSelf { + break // ASCII: everything before the tail scan is complete + } + if !utf8.RuneStart(b) { + continue // continuation byte: keep looking for its leading byte + } + // Leading byte: hold the sequence back iff it declares more bytes + // than the stream has produced so far (it may complete next block). + if utf8SeqLen(b) > len(s)-i { + cut = i + } + break + } + return sanitizeUTF8(s[:cut]), s[cut:] +} + +// PredictRich is the non-streaming inference path (grpc.AIModelRich). +// Returns one Reply whose Message is the aggregated assistant content and +// whose ChatDeltas carry the parsed content/reasoning/tool-call events. +func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) { + d.genMu.RLock() + defer d.genMu.RUnlock() + if d.gen == nil { + return nil, grpcerrors.ModelNotLoaded("dllm") + } + prompt, parse, optsJSON, err := d.prepareRequest(opts) + if err != nil { + return nil, err + } + + var out string + var genErr error + d.submit(func() { + out, genErr = d.gen.generate(prompt, optsJSON) + }) + if genErr != nil { + return nil, genErr + } + // Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings + // must be valid or grpc-go fails the whole reply at marshal time. + out = sanitizeUTF8(out) + + if !parse { + // Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt). + return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil + } + + // The prompt renders with add_generation_prompt; both thinking modes + // leave the model starting in content state (see the Gemma4Parser header + // comment), hence NewGemma4Parser(false). + parser := NewGemma4Parser(false) + if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil { + return reply, nil + } + // Everything was markers (or out was empty): an empty but non-nil Reply. + return &pb.Reply{}, nil +} + +// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one +// Reply per committed diffusion block that produced deltas. Per the +// interface contract the channel is only sent into here - the gRPC server +// closes it after this returns (opposite to legacy PredictStream). +func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error { + d.genMu.RLock() + defer d.genMu.RUnlock() + if d.gen == nil { + return grpcerrors.ModelNotLoaded("dllm") + } + prompt, parse, optsJSON, err := d.prepareRequest(opts) + if err != nil { + return err + } + + var parser *Gemma4Parser + if parse { + parser = NewGemma4Parser(false) + } + // emit runs inside onBlock, i.e. on the thread driving the C generate. + // Sending on results can block on a slow consumer, but the server-side + // pump (pkg/grpc/server.go PredictStream) drains continuously and drops + // undeliverable sends, so this backpressure is brief and bounded - and + // pausing the diffusion loop under it is the desired behavior anyway. + emit := func(text string) { + if !parse { + if text != "" { + results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}} + } + return + } + deltas := parser.Feed(text) + if reply := replyFromDeltas(deltas); reply != nil { + results <- reply + } + } + // onBlock guards emit (and through it the parser) against invalid UTF-8: + // a multi-byte character split across block boundaries is held back until + // it completes (see splitValidUTF8), so proto3 marshaling never fails. + var carry string + onBlock := func(block string) { + var text string + text, carry = splitValidUTF8(carry, block) + emit(text) + } + + var genErr error + d.submit(func() { + genErr = d.gen.generateStream(prompt, optsJSON, onBlock) + }) + if genErr != nil { + return genErr + } + if carry != "" { + // The stream ended mid-sequence: the held-back bytes can no longer + // complete, so flush them through the U+FFFD last resort. + emit(sanitizeUTF8(carry)) + } + if parse { + if reply := replyFromDeltas(parser.Close()); reply != nil { + results <- reply + } + } + return nil +} + +// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply, +// or nil when the batch is empty (markers consumed, nothing emitted yet). +// Message mirrors the batch's content text so legacy chan-string consumers +// see exactly the displayed tokens. +func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply { + if len(deltas) == 0 { + return nil + } + var content strings.Builder + for _, delta := range deltas { + content.WriteString(delta.GetContent()) + } + return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas} +} + +// Predict is the legacy (string, error) signature; the gRPC server prefers +// PredictRich, this exists for non-rich callers (cloud-proxy precedent). +func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) { + reply, err := d.PredictRich(opts) + if err != nil { + return "", err + } + return string(reply.GetMessage()), nil +} + +// PredictStream is the legacy chan-string path: rich replies reduced to +// their content text. Note the inverted channel ownership - the LEGACY +// contract requires the impl to close the channel. +func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error { + defer close(results) + richCh := make(chan *pb.Reply) + errCh := make(chan error, 1) + go func() { + errCh <- d.PredictStreamRich(opts, richCh) + close(richCh) + }() + for reply := range richCh { + if msg := reply.GetMessage(); len(msg) > 0 { + results <- string(msg) + } + } + return <-errCh +} + +// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C +// side prepends bos per the vocab) and decodes the returned id array. +func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { + d.genMu.RLock() + defer d.genMu.RUnlock() + if d.gen == nil { + return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm") + } + var out string + var tokErr error + d.submit(func() { + out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt()) + }) + if tokErr != nil { + return pb.TokenizationResponse{}, tokErr + } + var tokens []int32 + if err := json.Unmarshal([]byte(out), &tokens); err != nil { + return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err) + } + return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil +} diff --git a/backend/go/dllm/dllm_test.go b/backend/go/dllm/dllm_test.go old mode 100644 new mode 100755 index a6f1e5697..22ef767cd --- a/backend/go/dllm/dllm_test.go +++ b/backend/go/dllm/dllm_test.go @@ -1,13 +1,19 @@ package main import ( + "errors" "os" + "runtime" "sync" "testing" + "time" + "unicode/utf8" "unsafe" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) func TestDllm(t *testing.T) { @@ -131,10 +137,59 @@ var _ = Describe("buildOptsJSON", func() { }) }) +var _ = Describe("splitValidUTF8", func() { + It("holds back a trailing incomplete sequence and completes it next block", func() { + emit, carry := splitValidUTF8("", "caf\xe2") + Expect(emit).To(Equal("caf")) + Expect(carry).To(Equal("\xe2")) + + emit, carry = splitValidUTF8(carry, "\x82") + Expect(emit).To(BeEmpty()) + Expect(carry).To(Equal("\xe2\x82")) + + emit, carry = splitValidUTF8(carry, "\xac!") + Expect(emit).To(Equal("€!")) + Expect(carry).To(BeEmpty()) + }) + + It("holds back up to 3 bytes of a 4-byte sequence", func() { + emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte + Expect(emit).To(Equal("x")) + Expect(carry).To(Equal("\xf0\x9f\x98")) + + emit, carry = splitValidUTF8(carry, "\x80") + Expect(emit).To(Equal("😀")) + Expect(carry).To(BeEmpty()) + }) + + It("replaces undecodable bytes immediately instead of carrying them", func() { + // A mid-string invalid byte can never complete: carrying it would let + // the carry grow unboundedly, so it is substituted on the spot. + emit, carry := splitValidUTF8("", "a\xe2bc") + Expect(emit).To(Equal("a�bc")) + Expect(carry).To(BeEmpty()) + + // Orphan continuation bytes at the end have no leading byte to wait + // for either. + emit, carry = splitValidUTF8("", "a\x82") + Expect(emit).To(Equal("a�")) + Expect(carry).To(BeEmpty()) + }) + + It("passes pure ASCII and complete UTF-8 through untouched", func() { + emit, carry := splitValidUTF8("", "héllo €") + Expect(emit).To(Equal("héllo €")) + Expect(carry).To(BeEmpty()) + }) +}) + var _ = Describe("goStringFromCPtr", func() { It("copies a NUL-terminated buffer", func() { buf := []byte("dllm\x00") s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0]))) + // The uintptr round-trip hides buf from the GC's liveness analysis; + // keep it reachable until after the copy. + runtime.KeepAlive(buf) Expect(s).To(Equal("dllm")) }) @@ -142,3 +197,575 @@ var _ = Describe("goStringFromCPtr", func() { Expect(goStringFromCPtr(0)).To(Equal("")) }) }) + +// --------------------------------------------------------------------------- +// Backend wiring (T4): fake-generator specs, no libdllm.so required. +// --------------------------------------------------------------------------- + +type fakeGenCall struct { + prompt string + optsJSON string +} + +// fakeGen implements generator in-process. It records every call (prompt + +// opts JSON), tracks concurrent in-flight calls to prove worker +// serialization, and replays canned output (out for generate/tokenize, +// blocks for generateStream). +type fakeGen struct { + mu sync.Mutex + calls []fakeGenCall + inFlight int + maxInFlight int + + out string + blocks []string + err error + delay time.Duration +} + +func (f *fakeGen) begin(prompt, optsJSON string) { + f.mu.Lock() + defer f.mu.Unlock() + f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON}) + f.inFlight++ + if f.inFlight > f.maxInFlight { + f.maxInFlight = f.inFlight + } +} + +func (f *fakeGen) end() { + f.mu.Lock() + defer f.mu.Unlock() + f.inFlight-- +} + +func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) { + f.mu.Lock() + defer f.mu.Unlock() + return append([]fakeGenCall(nil), f.calls...), f.maxInFlight +} + +func (f *fakeGen) generate(prompt, optsJSON string) (string, error) { + f.begin(prompt, optsJSON) + defer f.end() + if f.delay > 0 { + time.Sleep(f.delay) + } + return f.out, f.err +} + +func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error { + f.begin(prompt, optsJSON) + defer f.end() + if f.err != nil { + return f.err + } + for _, b := range f.blocks { + onBlock(b) + } + return nil +} + +func (f *fakeGen) tokenizeJSON(text string) (string, error) { + f.begin(text, "") + defer f.end() + return f.out, f.err +} + +func (f *fakeGen) cancel() {} +func (f *fakeGen) free() {} + +// newTestDllm assembles a backend around a fake generator (bypassing Load, +// which needs libdllm.so) and registers cleanup of the worker goroutine. +func newTestDllm(g generator, genOpts map[string]any) *Dllm { + d := &Dllm{gen: g, genOpts: genOpts} + d.startWorker() + DeferCleanup(func() { Expect(d.Free()).To(Succeed()) }) + return d +} + +// drainReplies empties ch without blocking, failing the spec if the channel +// was closed (PredictStreamRich must NOT close it - interface.go contract). +// Size ch above the expected reply count: an overflow deadlocks the spec on +// the producer's send instead of failing it. +func drainReplies(ch chan *pb.Reply) []*pb.Reply { + var out []*pb.Reply + for { + select { + case r, ok := <-ch: + if !ok { + Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)") + } + expectValidUTF8Reply(r) + out = append(out, r) + default: + return out + } + } +} + +// expectValidUTF8Reply is the blanket guard for the proto3 marshaling +// contract: grpc-go rejects any string field carrying invalid UTF-8, so every +// reply field that ends up in a proto string must validate. +func expectValidUTF8Reply(r *pb.Reply) { + GinkgoHelper() + Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8") + for _, delta := range r.GetChatDeltas() { + Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8") + Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8") + for _, tc := range delta.GetToolCalls() { + Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8") + Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8") + } + } +} + +var _ = Describe("Dllm backend wiring", func() { + Describe("PredictRich", func() { + It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() { + fake := &fakeGen{out: "<|channel>thought\nponderingThe answer."} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}}, + Metadata: map[string]string{"enable_thinking": "true"}, + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls).To(HaveLen(1)) + // The enable_thinking=true render from the transformers fixture. + Expect(calls[0].prompt).To(Equal( + "<|turn>system\n<|think|>\n\n<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n")) + + Expect(string(reply.GetMessage())).To(Equal("The answer.")) + Expect(reply.GetChatDeltas()).To(HaveLen(2)) + Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering")) + Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer.")) + }) + + It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() { + fake := &fakeGen{out: "hi"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}}, + }) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + // No-thinking render: the template pre-opens AND pre-closes an + // empty thought channel in the generation prompt. + Expect(calls[0].prompt).To(Equal( + "<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n<|channel>thought\n")) + }) + + It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() { + // Marker-looking text must survive untouched: in raw-prompt mode + // the caller templates themselves and the Go-side extraction + // applies, so the backend must not interpret the output. + fake := &fakeGen{out: "<|channel>thought\nnot parsedtail"} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"}) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].prompt).To(Equal("my raw prompt")) + Expect(string(reply.GetMessage())).To(Equal(fake.out)) + Expect(reply.GetChatDeltas()).To(HaveLen(1)) + Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out)) + }) + + It("sanitizes invalid UTF-8 in the non-streaming output", func() { + // Byte-fallback tokens can decode to lone malformed bytes; the + // whole-output sanitize must replace them so proto3 marshaling of + // Message/ChatDeltas cannot fail. + fake := &fakeGen{out: "a\xe2b"} + d := newTestDllm(fake, nil) + + reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + expectValidUTF8Reply(reply) + Expect(string(reply.GetMessage())).To(Equal("a�b")) + Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a�b")) + }) + + It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"}) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7}) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`)) + }) + + It("omits n_predict and seed when unset so the engine defaults apply", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + + calls, _ := fake.snapshot() + Expect(calls[0].optsJSON).To(MatchJSON(`{}`)) + }) + + It("surfaces generator errors", func() { + fake := &fakeGen{err: errors.New("boom")} + d := newTestDllm(fake, nil) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).To(MatchError("boom")) + }) + + It("errors before generating when no model is loaded", func() { + d := &Dllm{} // no Load, no worker: must fail fast, not hang + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).To(HaveOccurred()) + }) + + It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() { + // server.go's Free has no locking of its own: the backend's genMu + // must hold Free back until the racing generate drains, instead of + // closing the jobs channel (panic) or freeing the C ctx under it. + fake := &fakeGen{out: "x", delay: 50 * time.Millisecond} + d := newTestDllm(fake, nil) + + predictDone := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + predictDone <- err + }() + // Wait until the fake generate is actually in flight (the read + // lock is held from before submit until PredictRich returns). + Eventually(func() int { + _, maxInFlight := fake.snapshot() + return maxInFlight + }).Should(Equal(1)) + + Expect(d.Free()).To(Succeed()) + // Free's write lock means the request finished before Free did. + var predictErr error + Eventually(predictDone).Should(Receive(&predictErr)) + Expect(predictErr).ToNot(HaveOccurred()) + }) + + It("returns model-not-loaded for requests after Free", func() { + fake := &fakeGen{out: "x"} + d := newTestDllm(fake, nil) + Expect(d.Free()).To(Succeed()) + + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).To(MatchError(ContainSubstring("model not loaded"))) + }) + + It("serializes concurrent requests through the worker goroutine", func() { + // dllm_capi.h: one ctx = one concurrent generate. Two overlapping + // PredictRich calls must execute the C calls one at a time. + fake := &fakeGen{out: "x", delay: 30 * time.Millisecond} + d := newTestDllm(fake, nil) + + var wg sync.WaitGroup + for range 2 { + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + _, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + }() + } + wg.Wait() + + calls, maxInFlight := fake.snapshot() + Expect(calls).To(HaveLen(2)) + Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue") + }) + }) + + Describe("PredictStreamRich", func() { + It("emits one reply per delta-producing block and leaves the channel open", func() { + // Blocks split mid-marker and mid-payload: the parser's holdback + // must keep marker fragments out of the emitted deltas. + fake := &fakeGen{blocks: []string{ + "<|channel>thou", // partial channel open: no deltas yet + "ght\nponder", // header completes, reasoning starts + "ingHi ", // reasoning ends, content starts + "therediscarded", // turn ends: trailing text dropped + }} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }, ch) + Expect(err).ToNot(HaveOccurred()) + + replies := drainReplies(ch) + Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply") + + var content, reasoning string + for _, r := range replies { + for _, delta := range r.GetChatDeltas() { + content += delta.GetContent() + reasoning += delta.GetReasoningContent() + } + } + Expect(reasoning).To(Equal("pondering")) + Expect(content).To(Equal("Hi there")) + // Message mirrors each reply's content so legacy consumers see + // exactly the displayed tokens. + Expect(string(replies[1].GetMessage())).To(Equal("Hi ")) + Expect(string(replies[2].GetMessage())).To(Equal("there")) + }) + + It("streams raw blocks verbatim without use_tokenizer_template", func() { + fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch) + Expect(err).ToNot(HaveOccurred()) + + replies := drainReplies(ch) + Expect(replies).To(HaveLen(2), "empty blocks produce no reply") + Expect(string(replies[0].GetMessage())).To(Equal("abc")) + Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def")) + Expect(replies[1].GetChatDeltas()).To(HaveLen(1)) + }) + + It("flushes parser holdback after the stream ends", func() { + // The unterminated partial marker ""}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }, ch) + Expect(err).ToNot(HaveOccurred()) + + var content string + for _, r := range drainReplies(ch) { + for _, delta := range r.GetChatDeltas() { + content += delta.GetContent() + } + } + Expect(content).To(Equal("caf€")) + }) + + It("replaces an incomplete sequence left at stream end with U+FFFD", func() { + // A byte-fallback token can leave a lone leading byte (0xE2) that + // no later block completes: the final flush must substitute it, + // never emit it raw and never drop into a marshal error. + fake := &fakeGen{blocks: []string{"ok\xe2"}} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch) + Expect(err).ToNot(HaveOccurred()) + + var content string + for _, r := range drainReplies(ch) { + content += string(r.GetMessage()) + } + Expect(content).To(Equal("ok�")) + }) + + It("surfaces generator errors without sending replies", func() { + fake := &fakeGen{err: errors.New("stream boom")} + d := newTestDllm(fake, nil) + + ch := make(chan *pb.Reply, 16) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch) + Expect(err).To(MatchError("stream boom")) + Expect(drainReplies(ch)).To(BeEmpty()) + }) + + It("errors before generating when no model is loaded", func() { + d := &Dllm{} // no Load, no worker: must fail fast, not hang + ch := make(chan *pb.Reply, 1) + err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch) + Expect(err).To(MatchError(ContainSubstring("model not loaded"))) + Expect(drainReplies(ch)).To(BeEmpty()) + }) + }) + + Describe("legacy Predict/PredictStream adapters", func() { + It("Predict returns the aggregated content string", func() { + fake := &fakeGen{out: "plain text"} + d := newTestDllm(fake, nil) + + out, err := d.Predict(&pb.PredictOptions{Prompt: "p"}) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal("plain text")) + }) + + It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() { + fake := &fakeGen{blocks: []string{"a", "b"}} + d := newTestDllm(fake, nil) + + ch := make(chan string, 16) + Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed()) + + var got []string + for s := range ch { // terminates only if the impl closed ch + got = append(got, s) + } + Expect(got).To(Equal([]string{"a", "b"})) + }) + }) + + Describe("TokenizeString", func() { + It("decodes the C-side JSON id array", func() { + fake := &fakeGen{out: "[2,18]"} + d := newTestDllm(fake, nil) + + resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Length).To(Equal(int32(2))) + Expect(resp.Tokens).To(Equal([]int32{2, 18})) + + calls, _ := fake.snapshot() + Expect(calls[0].prompt).To(Equal("hello")) + }) + + It("fails loud on a malformed id array", func() { + fake := &fakeGen{out: "not json"} + d := newTestDllm(fake, nil) + + _, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).To(HaveOccurred()) + }) + + It("errors before tokenizing when no model is loaded", func() { + d := &Dllm{} // no Load, no worker: must fail fast, not hang + _, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).To(MatchError(ContainSubstring("model not loaded"))) + }) + }) + + Describe("parseModelGenOpts", func() { + It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() { + got := parseModelGenOpts([]string{ + "eb_max_steps:16", + "eb_t_min:0.25", + "kv_cache:auto", + "blocks:4", + "unrelated_key:1", // other layers' options: skipped + "malformed", // no colon: skipped + }) + Expect(got).To(Equal(map[string]any{ + "eb_max_steps": int64(16), + "eb_t_min": 0.25, + "kv_cache": "auto", + "blocks": int64(4), + })) + }) + + It("round-trips through buildOptsJSON (only flat scalars are produced)", func() { + got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"}) + out, err := buildOptsJSON(got) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`)) + }) + }) +}) + +// --------------------------------------------------------------------------- +// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture. +// --------------------------------------------------------------------------- + +var _ = Describe("Dllm backend (real tiny model)", func() { + BeforeEach(func() { + if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" { + Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip") + } + ensureLibLoaded() + Expect(libLoadErr).ToNot(HaveOccurred()) + }) + + It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() { + d := &Dllm{} + Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed()) + DeferCleanup(func() { Expect(d.Free()).To(Succeed()) }) + + // TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18]. + resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"}) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Tokens).To(Equal([]int32{2, 18})) + Expect(resp.Length).To(Equal(int32(2))) + + req := &pb.PredictOptions{ + UseTokenizerTemplate: true, + Messages: []*pb.Message{{Role: "user", Content: "hello"}}, + Tokens: 16, + Seed: 7, + } + + // Non-streaming: the tiny random-weight model emits arbitrary vocab + // words; with no gemma4 markers in them everything is content. + reply, err := d.PredictRich(req) + Expect(err).ToNot(HaveOccurred()) + Expect(string(reply.GetMessage())).ToNot(BeEmpty()) + Expect(reply.GetChatDeltas()).ToNot(BeEmpty()) + + // Streaming: at least one reply, and the channel-ownership rule is + // honored (drainReplies fails the spec on a closed channel). + ch := make(chan *pb.Reply, 64) + Expect(d.PredictStreamRich(req, ch)).To(Succeed()) + replies := drainReplies(ch) + Expect(replies).ToNot(BeEmpty()) + var streamed string + for _, r := range replies { + streamed += string(r.GetMessage()) + } + Expect(streamed).ToNot(BeEmpty()) + }) +}) diff --git a/backend/go/dllm/main.go b/backend/go/dllm/main.go old mode 100644 new mode 100755