feat(dllm): rich gRPC backend with ChatDelta streaming

Implements PredictRich/PredictStreamRich (legacy methods delegate),
TokenizeString, and Load over the purego binding. A single worker
goroutine serializes all C calls per the dllm.cpp one-generate-per-ctx
contract (cancel is the documented exception); an RWMutex guards Free
against in-flight requests. Under use_tokenizer_template the gemma4
renderer and streaming parser own templating and ChatDelta extraction;
raw-prompt mode passes through verbatim. enable_thinking is opt-in via
request metadata (the gemma4 template treats thinking as opt-in).

Assisted-by: Claude Code (Fable 5)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-11 16:14:37 +00:00
parent 294c04ae2f
commit 99184809fa
4 changed files with 1176 additions and 22 deletions

35
backend/go/dllm/capi.go Normal file → Executable file
View File

@@ -16,15 +16,12 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"sync"
"sync/atomic"
"unsafe"
"github.com/ebitengine/purego"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
@@ -48,18 +45,6 @@ var (
cppCancel func(ctx uintptr)
)
// Dllm is the LocalAI gRPC backend over the dllm.cpp C-ABI. T1 ships only
// the binding scaffold; Load/PredictRich/PredictStreamRich (and the move to
// a dedicated dllm.go with the per-model worker goroutine) land in T4.
type Dllm struct {
base.Base
}
// Load is not wired yet: the binding smoke drives the C functions directly.
func (d *Dllm) Load(opts *pb.ModelOptions) error {
return errors.New("dllm: model loading not implemented yet (backend wiring lands in T4)")
}
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
func cAbiVersion() int32 {
return cppAbiVersion()
@@ -218,6 +203,11 @@ func cCancel(h uintptr) {
// nested objects/arrays loudly; bools are rejected here too because the
// scanner has no concept of them. Fail loud rather than let an option be
// silently misread.
//
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
func buildOptsJSON(opts map[string]any) (string, error) {
if len(opts) == 0 {
return "{}", nil
@@ -246,17 +236,18 @@ func buildOptsJSON(opts map[string]any) (string, error) {
// caller owns, or a callback argument only valid during the invocation);
// owning callers must free it via cppFreeString after the copy lands.
//
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
// check, which can't distinguish a C-owned heap pointer from Go-managed
// memory. It is safe here: the pointer addresses C memory the Go GC neither
// tracks nor moves, and we dereference it immediately to copy the bytes out,
// the same pattern (and the same tolerated warning) as the parakeet-cpp and
// whisper backends.
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
// through &cptr below is equivalent at runtime and keeps plain `go vet`
// clean. It is safe either way: the pointer addresses C memory the Go GC
// neither tracks nor moves, and we dereference it immediately to copy the
// bytes out.
func goStringFromCPtr(cptr uintptr) string {
if cptr == 0 {
return ""
}
p := unsafe.Pointer(cptr) //nolint:govet // C-owned buffer, not Go-GC memory (see doc above)
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
n := 0
for *(*byte)(unsafe.Add(p, n)) != 0 {
n++

536
backend/go/dllm/dllm.go Executable file
View File

@@ -0,0 +1,536 @@
package main
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
//
// Wiring overview:
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
// goroutine that serializes every C call (see submit).
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
// request carries raw messages (use_tokenizer_template), the backend owns
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
// - The legacy Predict / PredictStream methods delegate to the rich pair
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"unicode/utf8"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
)
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
// family, while tests substitute a fake to exercise prompt construction,
// parsing and serialization without libdllm.so.
type generator interface {
generate(prompt, optsJSON string) (string, error)
// generateStream invokes onBlock once per committed diffusion block, on
// the thread running the C call, before returning.
generateStream(prompt, optsJSON string, onBlock func(text string)) error
tokenizeJSON(text string) (string, error)
// cancel is the ONE entry point safe to call concurrently with an
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
// atomic; everything else must be externally serialized per ctx).
cancel()
free()
}
// capiGenerator is the production generator over one dllm_ctx handle.
type capiGenerator struct {
h uintptr
}
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
return cGenerate(g.h, prompt, optsJSON)
}
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
// passed as nil for now: a future progress hook for the React UI can
// plumb it through without touching the C binding.
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
}
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
return cTokenizeJSON(g.h, text)
}
func (g *capiGenerator) cancel() {
cCancel(g.h)
}
func (g *capiGenerator) free() {
cFree(g.h)
}
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
// one backend process per model).
type Dllm struct {
base.Base
gen generator
// genOpts holds the model-level generation overrides parsed from
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
// them per-generate, not per-load, so they are merged into every
// request's opts JSON (requestOptsJSON).
genOpts map[string]any
// jobs is the per-model worker queue. dllm_capi.h requires every entry
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
// while a call is in flight). A single goroutine owning all C calls makes
// that contract structural instead of relying on lock discipline.
jobs chan func()
workerWG sync.WaitGroup
// genMu guards gen against Free racing in-flight requests: requests hold
// the read lock for their full duration (they stay concurrent with each
// other - the worker still serializes the C calls), Free takes the write
// lock so it can only run when no request is in flight.
genMu sync.RWMutex
}
func (d *Dllm) startWorker() {
d.jobs = make(chan func())
d.workerWG.Add(1)
go func() {
defer d.workerWG.Done()
for job := range d.jobs {
job()
}
}()
}
// submit runs job on the worker goroutine and waits for it to finish.
// Concurrent gRPC requests therefore queue up and execute one at a time
// against the single dllm_ctx.
func (d *Dllm) submit(job func()) {
done := make(chan struct{})
d.jobs <- func() {
defer close(done)
job()
}
<-done
}
// Load opens the GGUF and prepares the worker. Load-time engine parameters
// travel as the flat params JSON of dllm_capi_load; generation overrides
// from Options are stored for per-request opts JSON instead (the C-ABI has
// no per-load sampler state).
func (d *Dllm) Load(opts *pb.ModelOptions) error {
if d.gen != nil {
return errors.New("dllm: model already loaded")
}
params := map[string]any{
"n_gpu_layers": opts.GetNGPULayers(),
}
if opts.GetThreads() > 0 {
params["n_threads"] = opts.GetThreads()
}
if opts.GetContextSize() > 0 {
params["ctx_len"] = opts.GetContextSize()
}
paramsJSON, err := buildOptsJSON(params)
if err != nil {
return err
}
d.genOpts = parseModelGenOpts(opts.GetOptions())
h := cLoad(opts.GetModelFile(), paramsJSON)
if h == 0 {
// No ctx exists on load failure, so last_error(NULL) only carries the
// static NULL-ctx message; the real reason is on the backend's stderr.
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
}
d.gen = &capiGenerator{h: h}
d.startWorker()
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
return nil
}
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
//
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
// model-unload path around line 764) calls Free with no locking of its own,
// and base.Base provides none either. Without it a request racing Free would
// panic sending on the closed jobs channel - or worse, generate on a freed C
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
// clean "model not loaded" error instead of a crash.
func (d *Dllm) Free() error {
d.genMu.Lock()
defer d.genMu.Unlock()
if d.gen == nil {
return nil
}
d.submit(d.gen.free)
close(d.jobs)
d.workerWG.Wait()
d.gen = nil
return nil
}
// Cancel requests cancellation of the in-flight generate. It deliberately
// bypasses the worker queue: dllm_capi_cancel is the one call the C-ABI
// allows from any goroutine mid-generate (it only flips an atomic).
//
// LIMITATION: nothing invokes this on client disconnect today. The gRPC
// server (pkg/grpc/server.go) does not hand the request/stream context to
// Predict/PredictStreamRich, so a dropped HTTP client cannot reach the
// backend until that plumbing exists; the method is here so future server
// wiring (or an admin RPC) has something to call. Note dllm_capi.h's
// cancel-reset race: each generate resets the flag on entry, so a caller
// racing a new generate should re-issue Cancel.
func (d *Dllm) Cancel() {
if d.gen != nil {
d.gen.cancel()
}
}
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
// the engine. Options is a shared free-form bag (other layers put their own
// entries there), so unknown keys are skipped with a warning, not an error.
var dllmGenOptKeys = map[string]bool{
"blocks": true,
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
}
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
// first successful parse (int, then float, else string) to match the C
// scanner's number/string scalars.
func parseModelGenOpts(options []string) map[string]any {
out := map[string]any{}
for _, o := range options {
key, val, found := strings.Cut(o, ":")
if !found {
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
continue
}
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
continue
}
out[key] = parseScalarOpt(val)
}
return out
}
func parseScalarOpt(v string) any {
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
return iv
}
if fv, err := strconv.ParseFloat(v, 64); err == nil {
return fv
}
return v
}
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
// template guards every thinking branch with `enable_thinking is defined and
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
// no-thinking render pre-closes an empty thought channel that the OFF
// default must produce.
func metadataEnableThinking(opts *pb.PredictOptions) bool {
v := opts.GetMetadata()["enable_thinking"]
return v == "true" || v == "1"
}
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
// and raw messages the backend owns templating (RenderGemma4) and the output
// is in the known gemma4 format, so parse=true. Without it the caller
// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or
// a bare completion): the prompt passes through verbatim and the output is
// NOT gemma4-parsed - it is emitted as plain content and the Go side's
// extraction applies, as for any non-autoparsing backend.
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true)
return prompt, true, err
}
return opts.GetPrompt(), false, nil
}
// requestOptsJSON merges the model-level overrides with the request's
// sampling fields into the flat opts JSON for one generate call.
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
m := make(map[string]any, len(d.genOpts)+2)
for k, v := range d.genOpts {
m[k] = v
}
if n := opts.GetTokens(); n > 0 {
// The engine rounds n_predict UP to a whole number of diffusion
// blocks (the canvas is denoised block-wise), so the completion may
// run slightly past the requested budget. Tokens==0 omits the key so
// the engine's GGUF-metadata default applies (the C-ABI documents
// per-key defaults; no hardcoded 256 like ds4's grpc-server).
m["n_predict"] = n
}
if s := opts.GetSeed(); s > 0 {
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
// is omitted: proto3 cannot distinguish 0 from unset, and negative
// values conventionally mean "random" across LocalAI backends.
m["seed"] = s
}
return buildOptsJSON(m)
}
// prepareRequest is the shared prologue of the rich methods: resolve the
// prompt (and whether the output gets gemma4-parsed) and build the per-call
// opts JSON.
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) {
prompt, parse, err = buildPrompt(opts)
if err != nil {
return "", false, "", err
}
optsJSON, err = d.requestOptsJSON(opts)
if err != nil {
return "", false, "", err
}
return prompt, parse, optsJSON, nil
}
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
// every string destined for a Reply/ChatDelta must pass through here (or
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
// undecodable: replace with U+FFFD rather than crash the stream.
func sanitizeUTF8(s string) string {
if utf8.ValidString(s) {
return s
}
return strings.ToValidUTF8(s, "<22>")
}
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
// (1 for bytes that can never lead a multi-byte sequence, so they are never
// held back and fall through to sanitizeUTF8's replacement).
func utf8SeqLen(b byte) int {
switch {
case b&0xE0 == 0xC0:
return 2
case b&0xF0 == 0xE0:
return 3
case b&0xF8 == 0xF0:
return 4
default:
return 1
}
}
// splitValidUTF8 prepends the previous block's carry to the new block and
// splits the result into text safe to emit now and a trailing INCOMPLETE
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
// the per-block detokenize can split a multi-byte character across block
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
// suffix that can still become a valid rune is withheld; bytes that are
// already undecodable are replaced immediately so the carry stays bounded.
func splitValidUTF8(carry, block string) (emit, newCarry string) {
s := carry + block
cut := len(s)
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
b := s[i]
if b < utf8.RuneSelf {
break // ASCII: everything before the tail scan is complete
}
if !utf8.RuneStart(b) {
continue // continuation byte: keep looking for its leading byte
}
// Leading byte: hold the sequence back iff it declares more bytes
// than the stream has produced so far (it may complete next block).
if utf8SeqLen(b) > len(s)-i {
cut = i
}
break
}
return sanitizeUTF8(s[:cut]), s[cut:]
}
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
// Returns one Reply whose Message is the aggregated assistant content and
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return nil, grpcerrors.ModelNotLoaded("dllm")
}
prompt, parse, optsJSON, err := d.prepareRequest(opts)
if err != nil {
return nil, err
}
var out string
var genErr error
d.submit(func() {
out, genErr = d.gen.generate(prompt, optsJSON)
})
if genErr != nil {
return nil, genErr
}
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
// must be valid or grpc-go fails the whole reply at marshal time.
out = sanitizeUTF8(out)
if !parse {
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
}
// The prompt renders with add_generation_prompt; both thinking modes
// leave the model starting in content state (see the Gemma4Parser header
// comment), hence NewGemma4Parser(false).
parser := NewGemma4Parser(false)
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
return reply, nil
}
// Everything was markers (or out was empty): an empty but non-nil Reply.
return &pb.Reply{}, nil
}
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
// Reply per committed diffusion block that produced deltas. Per the
// interface contract the channel is only sent into here - the gRPC server
// closes it after this returns (opposite to legacy PredictStream).
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return grpcerrors.ModelNotLoaded("dllm")
}
prompt, parse, optsJSON, err := d.prepareRequest(opts)
if err != nil {
return err
}
var parser *Gemma4Parser
if parse {
parser = NewGemma4Parser(false)
}
// emit runs inside onBlock, i.e. on the thread driving the C generate.
// Sending on results can block on a slow consumer, but the server-side
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
// undeliverable sends, so this backpressure is brief and bounded - and
// pausing the diffusion loop under it is the desired behavior anyway.
emit := func(text string) {
if !parse {
if text != "" {
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
}
return
}
deltas := parser.Feed(text)
if reply := replyFromDeltas(deltas); reply != nil {
results <- reply
}
}
// onBlock guards emit (and through it the parser) against invalid UTF-8:
// a multi-byte character split across block boundaries is held back until
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
var carry string
onBlock := func(block string) {
var text string
text, carry = splitValidUTF8(carry, block)
emit(text)
}
var genErr error
d.submit(func() {
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
})
if genErr != nil {
return genErr
}
if carry != "" {
// The stream ended mid-sequence: the held-back bytes can no longer
// complete, so flush them through the U+FFFD last resort.
emit(sanitizeUTF8(carry))
}
if parse {
if reply := replyFromDeltas(parser.Close()); reply != nil {
results <- reply
}
}
return nil
}
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
// or nil when the batch is empty (markers consumed, nothing emitted yet).
// Message mirrors the batch's content text so legacy chan-string consumers
// see exactly the displayed tokens.
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
if len(deltas) == 0 {
return nil
}
var content strings.Builder
for _, delta := range deltas {
content.WriteString(delta.GetContent())
}
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
}
// Predict is the legacy (string, error) signature; the gRPC server prefers
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
reply, err := d.PredictRich(opts)
if err != nil {
return "", err
}
return string(reply.GetMessage()), nil
}
// PredictStream is the legacy chan-string path: rich replies reduced to
// their content text. Note the inverted channel ownership - the LEGACY
// contract requires the impl to close the channel.
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
defer close(results)
richCh := make(chan *pb.Reply)
errCh := make(chan error, 1)
go func() {
errCh <- d.PredictStreamRich(opts, richCh)
close(richCh)
}()
for reply := range richCh {
if msg := reply.GetMessage(); len(msg) > 0 {
results <- string(msg)
}
}
return <-errCh
}
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
// side prepends bos per the vocab) and decodes the returned id array.
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
}
var out string
var tokErr error
d.submit(func() {
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
})
if tokErr != nil {
return pb.TokenizationResponse{}, tokErr
}
var tokens []int32
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
}
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
}

627
backend/go/dllm/dllm_test.go Normal file → Executable file
View File

@@ -1,13 +1,19 @@
package main
import (
"errors"
"os"
"runtime"
"sync"
"testing"
"time"
"unicode/utf8"
"unsafe"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
func TestDllm(t *testing.T) {
@@ -131,10 +137,59 @@ var _ = Describe("buildOptsJSON", func() {
})
})
var _ = Describe("splitValidUTF8", func() {
It("holds back a trailing incomplete sequence and completes it next block", func() {
emit, carry := splitValidUTF8("", "caf\xe2")
Expect(emit).To(Equal("caf"))
Expect(carry).To(Equal("\xe2"))
emit, carry = splitValidUTF8(carry, "\x82")
Expect(emit).To(BeEmpty())
Expect(carry).To(Equal("\xe2\x82"))
emit, carry = splitValidUTF8(carry, "\xac!")
Expect(emit).To(Equal("€!"))
Expect(carry).To(BeEmpty())
})
It("holds back up to 3 bytes of a 4-byte sequence", func() {
emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte
Expect(emit).To(Equal("x"))
Expect(carry).To(Equal("\xf0\x9f\x98"))
emit, carry = splitValidUTF8(carry, "\x80")
Expect(emit).To(Equal("😀"))
Expect(carry).To(BeEmpty())
})
It("replaces undecodable bytes immediately instead of carrying them", func() {
// A mid-string invalid byte can never complete: carrying it would let
// the carry grow unboundedly, so it is substituted on the spot.
emit, carry := splitValidUTF8("", "a\xe2bc")
Expect(emit).To(Equal("a<>bc"))
Expect(carry).To(BeEmpty())
// Orphan continuation bytes at the end have no leading byte to wait
// for either.
emit, carry = splitValidUTF8("", "a\x82")
Expect(emit).To(Equal("a<>"))
Expect(carry).To(BeEmpty())
})
It("passes pure ASCII and complete UTF-8 through untouched", func() {
emit, carry := splitValidUTF8("", "héllo €")
Expect(emit).To(Equal("héllo €"))
Expect(carry).To(BeEmpty())
})
})
var _ = Describe("goStringFromCPtr", func() {
It("copies a NUL-terminated buffer", func() {
buf := []byte("dllm\x00")
s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0])))
// The uintptr round-trip hides buf from the GC's liveness analysis;
// keep it reachable until after the copy.
runtime.KeepAlive(buf)
Expect(s).To(Equal("dllm"))
})
@@ -142,3 +197,575 @@ var _ = Describe("goStringFromCPtr", func() {
Expect(goStringFromCPtr(0)).To(Equal(""))
})
})
// ---------------------------------------------------------------------------
// Backend wiring (T4): fake-generator specs, no libdllm.so required.
// ---------------------------------------------------------------------------
type fakeGenCall struct {
prompt string
optsJSON string
}
// fakeGen implements generator in-process. It records every call (prompt +
// opts JSON), tracks concurrent in-flight calls to prove worker
// serialization, and replays canned output (out for generate/tokenize,
// blocks for generateStream).
type fakeGen struct {
mu sync.Mutex
calls []fakeGenCall
inFlight int
maxInFlight int
out string
blocks []string
err error
delay time.Duration
}
func (f *fakeGen) begin(prompt, optsJSON string) {
f.mu.Lock()
defer f.mu.Unlock()
f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON})
f.inFlight++
if f.inFlight > f.maxInFlight {
f.maxInFlight = f.inFlight
}
}
func (f *fakeGen) end() {
f.mu.Lock()
defer f.mu.Unlock()
f.inFlight--
}
func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) {
f.mu.Lock()
defer f.mu.Unlock()
return append([]fakeGenCall(nil), f.calls...), f.maxInFlight
}
func (f *fakeGen) generate(prompt, optsJSON string) (string, error) {
f.begin(prompt, optsJSON)
defer f.end()
if f.delay > 0 {
time.Sleep(f.delay)
}
return f.out, f.err
}
func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
f.begin(prompt, optsJSON)
defer f.end()
if f.err != nil {
return f.err
}
for _, b := range f.blocks {
onBlock(b)
}
return nil
}
func (f *fakeGen) tokenizeJSON(text string) (string, error) {
f.begin(text, "")
defer f.end()
return f.out, f.err
}
func (f *fakeGen) cancel() {}
func (f *fakeGen) free() {}
// newTestDllm assembles a backend around a fake generator (bypassing Load,
// which needs libdllm.so) and registers cleanup of the worker goroutine.
func newTestDllm(g generator, genOpts map[string]any) *Dllm {
d := &Dllm{gen: g, genOpts: genOpts}
d.startWorker()
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
return d
}
// drainReplies empties ch without blocking, failing the spec if the channel
// was closed (PredictStreamRich must NOT close it - interface.go contract).
// Size ch above the expected reply count: an overflow deadlocks the spec on
// the producer's send instead of failing it.
func drainReplies(ch chan *pb.Reply) []*pb.Reply {
var out []*pb.Reply
for {
select {
case r, ok := <-ch:
if !ok {
Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)")
}
expectValidUTF8Reply(r)
out = append(out, r)
default:
return out
}
}
}
// expectValidUTF8Reply is the blanket guard for the proto3 marshaling
// contract: grpc-go rejects any string field carrying invalid UTF-8, so every
// reply field that ends up in a proto string must validate.
func expectValidUTF8Reply(r *pb.Reply) {
GinkgoHelper()
Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8")
for _, delta := range r.GetChatDeltas() {
Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8")
Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8")
for _, tc := range delta.GetToolCalls() {
Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8")
Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8")
}
}
}
var _ = Describe("Dllm backend wiring", func() {
Describe("PredictRich", func() {
It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() {
fake := &fakeGen{out: "<|channel>thought\npondering<channel|>The answer.<turn|>"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
Metadata: map[string]string{"enable_thinking": "true"},
})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls).To(HaveLen(1))
// The enable_thinking=true render from the transformers fixture.
Expect(calls[0].prompt).To(Equal(
"<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n"))
Expect(string(reply.GetMessage())).To(Equal("The answer."))
Expect(reply.GetChatDeltas()).To(HaveLen(2))
Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering"))
Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer."))
})
It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() {
fake := &fakeGen{out: "hi"}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
// No-thinking render: the template pre-opens AND pre-closes an
// empty thought channel in the generation prompt.
Expect(calls[0].prompt).To(Equal(
"<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"))
})
It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() {
// Marker-looking text must survive untouched: in raw-prompt mode
// the caller templates themselves and the Go-side extraction
// applies, so the backend must not interpret the output.
fake := &fakeGen{out: "<|channel>thought\nnot parsed<channel|>tail"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].prompt).To(Equal("my raw prompt"))
Expect(string(reply.GetMessage())).To(Equal(fake.out))
Expect(reply.GetChatDeltas()).To(HaveLen(1))
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out))
})
It("sanitizes invalid UTF-8 in the non-streaming output", func() {
// Byte-fallback tokens can decode to lone malformed bytes; the
// whole-output sanitize must replace them so proto3 marshaling of
// Message/ChatDeltas cannot fail.
fake := &fakeGen{out: "a\xe2b"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
expectValidUTF8Reply(reply)
Expect(string(reply.GetMessage())).To(Equal("a<>b"))
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a<>b"))
})
It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"})
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
})
It("omits n_predict and seed when unset so the engine defaults apply", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].optsJSON).To(MatchJSON(`{}`))
})
It("surfaces generator errors", func() {
fake := &fakeGen{err: errors.New("boom")}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(MatchError("boom"))
})
It("errors before generating when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(HaveOccurred())
})
It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() {
// server.go's Free has no locking of its own: the backend's genMu
// must hold Free back until the racing generate drains, instead of
// closing the jobs channel (panic) or freeing the C ctx under it.
fake := &fakeGen{out: "x", delay: 50 * time.Millisecond}
d := newTestDllm(fake, nil)
predictDone := make(chan error, 1)
go func() {
defer GinkgoRecover()
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
predictDone <- err
}()
// Wait until the fake generate is actually in flight (the read
// lock is held from before submit until PredictRich returns).
Eventually(func() int {
_, maxInFlight := fake.snapshot()
return maxInFlight
}).Should(Equal(1))
Expect(d.Free()).To(Succeed())
// Free's write lock means the request finished before Free did.
var predictErr error
Eventually(predictDone).Should(Receive(&predictErr))
Expect(predictErr).ToNot(HaveOccurred())
})
It("returns model-not-loaded for requests after Free", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, nil)
Expect(d.Free()).To(Succeed())
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
})
It("serializes concurrent requests through the worker goroutine", func() {
// dllm_capi.h: one ctx = one concurrent generate. Two overlapping
// PredictRich calls must execute the C calls one at a time.
fake := &fakeGen{out: "x", delay: 30 * time.Millisecond}
d := newTestDllm(fake, nil)
var wg sync.WaitGroup
for range 2 {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
}()
}
wg.Wait()
calls, maxInFlight := fake.snapshot()
Expect(calls).To(HaveLen(2))
Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue")
})
})
Describe("PredictStreamRich", func() {
It("emits one reply per delta-producing block and leaves the channel open", func() {
// Blocks split mid-marker and mid-payload: the parser's holdback
// must keep marker fragments out of the emitted deltas.
fake := &fakeGen{blocks: []string{
"<|channel>thou", // partial channel open: no deltas yet
"ght\nponder", // header completes, reasoning starts
"ing<channel|>Hi ", // reasoning ends, content starts
"there<turn|>discarded", // turn ends: trailing text dropped
}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
replies := drainReplies(ch)
Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply")
var content, reasoning string
for _, r := range replies {
for _, delta := range r.GetChatDeltas() {
content += delta.GetContent()
reasoning += delta.GetReasoningContent()
}
}
Expect(reasoning).To(Equal("pondering"))
Expect(content).To(Equal("Hi there"))
// Message mirrors each reply's content so legacy consumers see
// exactly the displayed tokens.
Expect(string(replies[1].GetMessage())).To(Equal("Hi "))
Expect(string(replies[2].GetMessage())).To(Equal("there"))
})
It("streams raw blocks verbatim without use_tokenizer_template", func() {
fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
replies := drainReplies(ch)
Expect(replies).To(HaveLen(2), "empty blocks produce no reply")
Expect(string(replies[0].GetMessage())).To(Equal("abc"))
Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def"))
Expect(replies[1].GetChatDeltas()).To(HaveLen(1))
})
It("flushes parser holdback after the stream ends", func() {
// The unterminated partial marker "<chan" is held back during the
// stream and must come out as content on the final flush.
fake := &fakeGen{blocks: []string{"tail<chan"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
content += string(r.GetMessage())
}
Expect(content).To(Equal("tail<chan"))
})
It("reassembles a multi-byte character split across block boundaries", func() {
// Per-block detokenize can split "€" (E2 82 AC) as E2 | 82 AC.
// Emitting the lone E2 would make grpc-go fail the marshal of the
// whole reply; the trailing incomplete sequence must be held back
// and completed by the next block.
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac ok"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) { // drain asserts ValidString per reply
content += string(r.GetMessage())
}
Expect(content).To(Equal("caf€ ok"))
})
It("reassembles a split multi-byte character in parsed (gemma4) mode too", func() {
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac<turn|>"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
for _, delta := range r.GetChatDeltas() {
content += delta.GetContent()
}
}
Expect(content).To(Equal("caf€"))
})
It("replaces an incomplete sequence left at stream end with U+FFFD", func() {
// A byte-fallback token can leave a lone leading byte (0xE2) that
// no later block completes: the final flush must substitute it,
// never emit it raw and never drop into a marshal error.
fake := &fakeGen{blocks: []string{"ok\xe2"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
content += string(r.GetMessage())
}
Expect(content).To(Equal("ok<6F>"))
})
It("surfaces generator errors without sending replies", func() {
fake := &fakeGen{err: errors.New("stream boom")}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
Expect(err).To(MatchError("stream boom"))
Expect(drainReplies(ch)).To(BeEmpty())
})
It("errors before generating when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
ch := make(chan *pb.Reply, 1)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
Expect(drainReplies(ch)).To(BeEmpty())
})
})
Describe("legacy Predict/PredictStream adapters", func() {
It("Predict returns the aggregated content string", func() {
fake := &fakeGen{out: "plain text"}
d := newTestDllm(fake, nil)
out, err := d.Predict(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal("plain text"))
})
It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() {
fake := &fakeGen{blocks: []string{"a", "b"}}
d := newTestDllm(fake, nil)
ch := make(chan string, 16)
Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed())
var got []string
for s := range ch { // terminates only if the impl closed ch
got = append(got, s)
}
Expect(got).To(Equal([]string{"a", "b"}))
})
})
Describe("TokenizeString", func() {
It("decodes the C-side JSON id array", func() {
fake := &fakeGen{out: "[2,18]"}
d := newTestDllm(fake, nil)
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).ToNot(HaveOccurred())
Expect(resp.Length).To(Equal(int32(2)))
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
calls, _ := fake.snapshot()
Expect(calls[0].prompt).To(Equal("hello"))
})
It("fails loud on a malformed id array", func() {
fake := &fakeGen{out: "not json"}
d := newTestDllm(fake, nil)
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).To(HaveOccurred())
})
It("errors before tokenizing when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
})
})
Describe("parseModelGenOpts", func() {
It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() {
got := parseModelGenOpts([]string{
"eb_max_steps:16",
"eb_t_min:0.25",
"kv_cache:auto",
"blocks:4",
"unrelated_key:1", // other layers' options: skipped
"malformed", // no colon: skipped
})
Expect(got).To(Equal(map[string]any{
"eb_max_steps": int64(16),
"eb_t_min": 0.25,
"kv_cache": "auto",
"blocks": int64(4),
}))
})
It("round-trips through buildOptsJSON (only flat scalars are produced)", func() {
got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"})
out, err := buildOptsJSON(got)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`))
})
})
})
// ---------------------------------------------------------------------------
// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture.
// ---------------------------------------------------------------------------
var _ = Describe("Dllm backend (real tiny model)", func() {
BeforeEach(func() {
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip")
}
ensureLibLoaded()
Expect(libLoadErr).ToNot(HaveOccurred())
})
It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() {
d := &Dllm{}
Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed())
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
// TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18].
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).ToNot(HaveOccurred())
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
Expect(resp.Length).To(Equal(int32(2)))
req := &pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
Tokens: 16,
Seed: 7,
}
// Non-streaming: the tiny random-weight model emits arbitrary vocab
// words; with no gemma4 markers in them everything is content.
reply, err := d.PredictRich(req)
Expect(err).ToNot(HaveOccurred())
Expect(string(reply.GetMessage())).ToNot(BeEmpty())
Expect(reply.GetChatDeltas()).ToNot(BeEmpty())
// Streaming: at least one reply, and the channel-ownership rule is
// honored (drainReplies fails the spec on a closed channel).
ch := make(chan *pb.Reply, 64)
Expect(d.PredictStreamRich(req, ch)).To(Succeed())
replies := drainReplies(ch)
Expect(replies).ToNot(BeEmpty())
var streamed string
for _, r := range replies {
streamed += string(r.GetMessage())
}
Expect(streamed).ToNot(BeEmpty())
})
})

0
backend/go/dllm/main.go Normal file → Executable file
View File