mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-11 18:27:32 -04:00
feat(dllm): rich gRPC backend with ChatDelta streaming
Implements PredictRich/PredictStreamRich (legacy methods delegate), TokenizeString, and Load over the purego binding. A single worker goroutine serializes all C calls per the dllm.cpp one-generate-per-ctx contract (cancel is the documented exception); an RWMutex guards Free against in-flight requests. Under use_tokenizer_template the gemma4 renderer and streaming parser own templating and ChatDelta extraction; raw-prompt mode passes through verbatim. enable_thinking is opt-in via request metadata (the gemma4 template treats thinking as opt-in). Assisted-by: Claude Code (Fable 5) Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
35
backend/go/dllm/capi.go
Normal file → Executable file
35
backend/go/dllm/capi.go
Normal file → Executable file
@@ -16,15 +16,12 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
|
||||
@@ -48,18 +45,6 @@ var (
|
||||
cppCancel func(ctx uintptr)
|
||||
)
|
||||
|
||||
// Dllm is the LocalAI gRPC backend over the dllm.cpp C-ABI. T1 ships only
|
||||
// the binding scaffold; Load/PredictRich/PredictStreamRich (and the move to
|
||||
// a dedicated dllm.go with the per-model worker goroutine) land in T4.
|
||||
type Dllm struct {
|
||||
base.Base
|
||||
}
|
||||
|
||||
// Load is not wired yet: the binding smoke drives the C functions directly.
|
||||
func (d *Dllm) Load(opts *pb.ModelOptions) error {
|
||||
return errors.New("dllm: model loading not implemented yet (backend wiring lands in T4)")
|
||||
}
|
||||
|
||||
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
|
||||
func cAbiVersion() int32 {
|
||||
return cppAbiVersion()
|
||||
@@ -218,6 +203,11 @@ func cCancel(h uintptr) {
|
||||
// nested objects/arrays loudly; bools are rejected here too because the
|
||||
// scanner has no concept of them. Fail loud rather than let an option be
|
||||
// silently misread.
|
||||
//
|
||||
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
|
||||
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
|
||||
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
|
||||
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
|
||||
func buildOptsJSON(opts map[string]any) (string, error) {
|
||||
if len(opts) == 0 {
|
||||
return "{}", nil
|
||||
@@ -246,17 +236,18 @@ func buildOptsJSON(opts map[string]any) (string, error) {
|
||||
// caller owns, or a callback argument only valid during the invocation);
|
||||
// owning callers must free it via cppFreeString after the copy lands.
|
||||
//
|
||||
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
|
||||
// check, which can't distinguish a C-owned heap pointer from Go-managed
|
||||
// memory. It is safe here: the pointer addresses C memory the Go GC neither
|
||||
// tracks nor moves, and we dereference it immediately to copy the bytes out,
|
||||
// the same pattern (and the same tolerated warning) as the parakeet-cpp and
|
||||
// whisper backends.
|
||||
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
|
||||
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
|
||||
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
|
||||
// through &cptr below is equivalent at runtime and keeps plain `go vet`
|
||||
// clean. It is safe either way: the pointer addresses C memory the Go GC
|
||||
// neither tracks nor moves, and we dereference it immediately to copy the
|
||||
// bytes out.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned buffer, not Go-GC memory (see doc above)
|
||||
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
|
||||
536
backend/go/dllm/dllm.go
Executable file
536
backend/go/dllm/dllm.go
Executable file
@@ -0,0 +1,536 @@
|
||||
package main
|
||||
|
||||
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
|
||||
//
|
||||
// Wiring overview:
|
||||
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
|
||||
// goroutine that serializes every C call (see submit).
|
||||
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
|
||||
// request carries raw messages (use_tokenizer_template), the backend owns
|
||||
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
|
||||
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
|
||||
// - The legacy Predict / PredictStream methods delegate to the rich pair
|
||||
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
|
||||
|
||||
import (
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/mudler/LocalAI/pkg/grpc/base"
	"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/xlog"
)
|
||||
|
||||
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
// family, while tests substitute a fake to exercise prompt construction,
// parsing and serialization without libdllm.so.
type generator interface {
	// generate runs one generation for prompt with the flat opts JSON and
	// returns the produced text.
	generate(prompt, optsJSON string) (string, error)
	// generateStream invokes onBlock once per committed diffusion block, on
	// the thread running the C call, before returning.
	generateStream(prompt, optsJSON string, onBlock func(text string)) error
	// tokenizeJSON tokenizes text and returns the result as a JSON document
	// (callers decode it as an array of token ids).
	tokenizeJSON(text string) (string, error)
	// cancel is the ONE entry point safe to call concurrently with an
	// in-flight generate on the same ctx (dllm_capi.h: it only flips an
	// atomic; everything else must be externally serialized per ctx).
	cancel()
	// free releases the underlying ctx; the generator must not be used
	// afterwards.
	free()
}
|
||||
|
||||
// capiGenerator is the production generator over one dllm_ctx handle.
type capiGenerator struct {
	// h is the handle returned by cLoad; every method forwards it to the
	// matching c* binding.
	h uintptr
}

// generate forwards to cGenerate on this handle.
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
	return cGenerate(g.h, prompt, optsJSON)
}

// generateStream forwards to cGenerateStream on this handle, delivering
// committed blocks to onBlock.
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
	// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
	// passed as nil for now: a future progress hook for the React UI can
	// plumb it through without touching the C binding.
	return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
}

// tokenizeJSON forwards to cTokenizeJSON on this handle.
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
	return cTokenizeJSON(g.h, text)
}

// cancel forwards to cCancel on this handle.
func (g *capiGenerator) cancel() {
	cCancel(g.h)
}

// free forwards to cFree on this handle.
func (g *capiGenerator) free() {
	cFree(g.h)
}
|
||||
|
||||
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
// one backend process per model).
type Dllm struct {
	base.Base

	// gen is the generator for the loaded model: set by Load, reset to nil
	// by Free. Request methods treat nil as "model not loaded".
	gen generator
	// genOpts holds the model-level generation overrides parsed from
	// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
	// them per-generate, not per-load, so they are merged into every
	// request's opts JSON (requestOptsJSON).
	genOpts map[string]any

	// jobs is the per-model worker queue. dllm_capi.h requires every entry
	// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
	// ctx = one concurrent generate/tokenize; last_error is unsafe to read
	// while a call is in flight). A single goroutine owning all C calls makes
	// that contract structural instead of relying on lock discipline.
	jobs chan func()
	// workerWG tracks the worker goroutine so Free can wait for it to exit
	// after closing jobs.
	workerWG sync.WaitGroup

	// genMu guards gen against Free racing in-flight requests: requests hold
	// the read lock for their full duration (they stay concurrent with each
	// other - the worker still serializes the C calls), Free takes the write
	// lock so it can only run when no request is in flight.
	genMu sync.RWMutex
}
|
||||
|
||||
func (d *Dllm) startWorker() {
|
||||
d.jobs = make(chan func())
|
||||
d.workerWG.Add(1)
|
||||
go func() {
|
||||
defer d.workerWG.Done()
|
||||
for job := range d.jobs {
|
||||
job()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// submit runs job on the worker goroutine and waits for it to finish.
|
||||
// Concurrent gRPC requests therefore queue up and execute one at a time
|
||||
// against the single dllm_ctx.
|
||||
func (d *Dllm) submit(job func()) {
|
||||
done := make(chan struct{})
|
||||
d.jobs <- func() {
|
||||
defer close(done)
|
||||
job()
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
// Load opens the GGUF and prepares the worker. Load-time engine parameters
|
||||
// travel as the flat params JSON of dllm_capi_load; generation overrides
|
||||
// from Options are stored for per-request opts JSON instead (the C-ABI has
|
||||
// no per-load sampler state).
|
||||
func (d *Dllm) Load(opts *pb.ModelOptions) error {
|
||||
if d.gen != nil {
|
||||
return errors.New("dllm: model already loaded")
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"n_gpu_layers": opts.GetNGPULayers(),
|
||||
}
|
||||
if opts.GetThreads() > 0 {
|
||||
params["n_threads"] = opts.GetThreads()
|
||||
}
|
||||
if opts.GetContextSize() > 0 {
|
||||
params["ctx_len"] = opts.GetContextSize()
|
||||
}
|
||||
paramsJSON, err := buildOptsJSON(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.genOpts = parseModelGenOpts(opts.GetOptions())
|
||||
|
||||
h := cLoad(opts.GetModelFile(), paramsJSON)
|
||||
if h == 0 {
|
||||
// No ctx exists on load failure, so last_error(NULL) only carries the
|
||||
// static NULL-ctx message; the real reason is on the backend's stderr.
|
||||
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
|
||||
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
|
||||
}
|
||||
d.gen = &capiGenerator{h: h}
|
||||
d.startWorker()
|
||||
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
|
||||
//
|
||||
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
|
||||
// model-unload path around line 764) calls Free with no locking of its own,
|
||||
// and base.Base provides none either. Without it a request racing Free would
|
||||
// panic sending on the closed jobs channel - or worse, generate on a freed C
|
||||
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
|
||||
// clean "model not loaded" error instead of a crash.
|
||||
func (d *Dllm) Free() error {
|
||||
d.genMu.Lock()
|
||||
defer d.genMu.Unlock()
|
||||
if d.gen == nil {
|
||||
return nil
|
||||
}
|
||||
d.submit(d.gen.free)
|
||||
close(d.jobs)
|
||||
d.workerWG.Wait()
|
||||
d.gen = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel requests cancellation of the in-flight generate. It deliberately
|
||||
// bypasses the worker queue: dllm_capi_cancel is the one call the C-ABI
|
||||
// allows from any goroutine mid-generate (it only flips an atomic).
|
||||
//
|
||||
// LIMITATION: nothing invokes this on client disconnect today. The gRPC
|
||||
// server (pkg/grpc/server.go) does not hand the request/stream context to
|
||||
// Predict/PredictStreamRich, so a dropped HTTP client cannot reach the
|
||||
// backend until that plumbing exists; the method is here so future server
|
||||
// wiring (or an admin RPC) has something to call. Note dllm_capi.h's
|
||||
// cancel-reset race: each generate resets the flag on entry, so a caller
|
||||
// racing a new generate should re-issue Cancel.
|
||||
func (d *Dllm) Cancel() {
|
||||
if d.gen != nil {
|
||||
d.gen.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
|
||||
// the engine. Options is a shared free-form bag (other layers put their own
|
||||
// entries there), so unknown keys are skipped with a debug log, not an error.
|
||||
var dllmGenOptKeys = map[string]bool{
|
||||
"blocks": true,
|
||||
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
|
||||
}
|
||||
|
||||
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
|
||||
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
|
||||
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
|
||||
// first successful parse (int, then float, else string) to match the C
|
||||
// scanner's number/string scalars.
|
||||
func parseModelGenOpts(options []string) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, o := range options {
|
||||
key, val, found := strings.Cut(o, ":")
|
||||
if !found {
|
||||
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
|
||||
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
|
||||
continue
|
||||
}
|
||||
out[key] = parseScalarOpt(val)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// parseScalarOpt types an option value by first successful parse: a base-10
// int64, then a finite float64, otherwise the string unchanged.
//
// strconv.ParseFloat also accepts "NaN", "Inf" and "Infinity" (any case,
// optionally signed). Those are not JSON numbers and encoding/json refuses
// to marshal them, so a single such value would make every request's opts
// JSON fail to build. They are therefore kept as plain strings.
func parseScalarOpt(v string) any {
	if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
		return iv
	}
	if fv, err := strconv.ParseFloat(v, 64); err == nil && !math.IsNaN(fv) && !math.IsInf(fv, 0) {
		return fv
	}
	return v
}
|
||||
|
||||
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
|
||||
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
|
||||
// template guards every thinking branch with `enable_thinking is defined and
|
||||
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
|
||||
// no-thinking render pre-closes an empty thought channel that the OFF
|
||||
// default must produce.
|
||||
func metadataEnableThinking(opts *pb.PredictOptions) bool {
|
||||
v := opts.GetMetadata()["enable_thinking"]
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
|
||||
// and raw messages the backend owns templating (RenderGemma4) and the output
|
||||
// is in the known gemma4 format, so parse=true. Without it the caller
|
||||
// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or
|
||||
// a bare completion): the prompt passes through verbatim and the output is
|
||||
// NOT gemma4-parsed - it is emitted as plain content and the Go side's
|
||||
// extraction applies, as for any non-autoparsing backend.
|
||||
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
|
||||
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
|
||||
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true)
|
||||
return prompt, true, err
|
||||
}
|
||||
return opts.GetPrompt(), false, nil
|
||||
}
|
||||
|
||||
// requestOptsJSON merges the model-level overrides with the request's
|
||||
// sampling fields into the flat opts JSON for one generate call.
|
||||
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
|
||||
m := make(map[string]any, len(d.genOpts)+2)
|
||||
for k, v := range d.genOpts {
|
||||
m[k] = v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
// The engine rounds n_predict UP to a whole number of diffusion
|
||||
// blocks (the canvas is denoised block-wise), so the completion may
|
||||
// run slightly past the requested budget. Tokens==0 omits the key so
|
||||
// the engine's GGUF-metadata default applies (the C-ABI documents
|
||||
// per-key defaults; no hardcoded 256 like ds4's grpc-server).
|
||||
m["n_predict"] = n
|
||||
}
|
||||
if s := opts.GetSeed(); s > 0 {
|
||||
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
|
||||
// is omitted: proto3 cannot distinguish 0 from unset, and negative
|
||||
// values conventionally mean "random" across LocalAI backends.
|
||||
m["seed"] = s
|
||||
}
|
||||
return buildOptsJSON(m)
|
||||
}
|
||||
|
||||
// prepareRequest is the shared prologue of the rich methods: resolve the
|
||||
// prompt (and whether the output gets gemma4-parsed) and build the per-call
|
||||
// opts JSON.
|
||||
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) {
|
||||
prompt, parse, err = buildPrompt(opts)
|
||||
if err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
optsJSON, err = d.requestOptsJSON(opts)
|
||||
if err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
return prompt, parse, optsJSON, nil
|
||||
}
|
||||
|
||||
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
// every string destined for a Reply/ChatDelta must pass through here (or
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
// undecodable: replace with U+FFFD rather than crash the stream.
//
// Valid input is returned unchanged. Each run of consecutive invalid bytes
// collapses into a single U+FFFD (strings.ToValidUTF8 semantics).
func sanitizeUTF8(s string) string {
	if utf8.ValidString(s) {
		return s
	}
	// Spelled as an escape so the replacement character cannot be mangled by
	// tools that mishandle non-ASCII source bytes.
	return strings.ToValidUTF8(s, "\uFFFD")
}
|
||||
|
||||
// utf8SeqLen returns the sequence length a UTF-8 leading byte declares: 2
// for 110xxxxx, 3 for 1110xxxx, 4 for 11110xxx. Every other byte (ASCII,
// continuation bytes, 0xF8 and above) yields 1, so such bytes are never held
// back and fall through to sanitizeUTF8's replacement.
func utf8SeqLen(b byte) int {
	if b >= 0xC0 && b <= 0xDF {
		return 2
	}
	if b >= 0xE0 && b <= 0xEF {
		return 3
	}
	if b >= 0xF0 && b <= 0xF7 {
		return 4
	}
	return 1
}
|
||||
|
||||
// splitValidUTF8 prepends the previous block's carry to the new block and
|
||||
// splits the result into text safe to emit now and a trailing INCOMPLETE
|
||||
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
|
||||
// the per-block detokenize can split a multi-byte character across block
|
||||
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
|
||||
// suffix that can still become a valid rune is withheld; bytes that are
|
||||
// already undecodable are replaced immediately so the carry stays bounded.
|
||||
func splitValidUTF8(carry, block string) (emit, newCarry string) {
|
||||
s := carry + block
|
||||
cut := len(s)
|
||||
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
|
||||
b := s[i]
|
||||
if b < utf8.RuneSelf {
|
||||
break // ASCII: everything before the tail scan is complete
|
||||
}
|
||||
if !utf8.RuneStart(b) {
|
||||
continue // continuation byte: keep looking for its leading byte
|
||||
}
|
||||
// Leading byte: hold the sequence back iff it declares more bytes
|
||||
// than the stream has produced so far (it may complete next block).
|
||||
if utf8SeqLen(b) > len(s)-i {
|
||||
cut = i
|
||||
}
|
||||
break
|
||||
}
|
||||
return sanitizeUTF8(s[:cut]), s[cut:]
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
|
||||
// Returns one Reply whose Message is the aggregated assistant content and
|
||||
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
|
||||
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return nil, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out string
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
out, genErr = d.gen.generate(prompt, optsJSON)
|
||||
})
|
||||
if genErr != nil {
|
||||
return nil, genErr
|
||||
}
|
||||
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
|
||||
// must be valid or grpc-go fails the whole reply at marshal time.
|
||||
out = sanitizeUTF8(out)
|
||||
|
||||
if !parse {
|
||||
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
|
||||
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
|
||||
}
|
||||
|
||||
// The prompt renders with add_generation_prompt; both thinking modes
|
||||
// leave the model starting in content state (see the Gemma4Parser header
|
||||
// comment), hence NewGemma4Parser(false).
|
||||
parser := NewGemma4Parser(false)
|
||||
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
|
||||
return reply, nil
|
||||
}
|
||||
// Everything was markers (or out was empty): an empty but non-nil Reply.
|
||||
return &pb.Reply{}, nil
|
||||
}
|
||||
|
||||
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
|
||||
// Reply per committed diffusion block that produced deltas. Per the
|
||||
// interface contract the channel is only sent into here - the gRPC server
|
||||
// closes it after this returns (opposite to legacy PredictStream).
|
||||
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var parser *Gemma4Parser
|
||||
if parse {
|
||||
parser = NewGemma4Parser(false)
|
||||
}
|
||||
// emit runs inside onBlock, i.e. on the thread driving the C generate.
|
||||
// Sending on results can block on a slow consumer, but the server-side
|
||||
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
|
||||
// undeliverable sends, so this backpressure is brief and bounded - and
|
||||
// pausing the diffusion loop under it is the desired behavior anyway.
|
||||
emit := func(text string) {
|
||||
if !parse {
|
||||
if text != "" {
|
||||
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
|
||||
}
|
||||
return
|
||||
}
|
||||
deltas := parser.Feed(text)
|
||||
if reply := replyFromDeltas(deltas); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
// onBlock guards emit (and through it the parser) against invalid UTF-8:
|
||||
// a multi-byte character split across block boundaries is held back until
|
||||
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
|
||||
var carry string
|
||||
onBlock := func(block string) {
|
||||
var text string
|
||||
text, carry = splitValidUTF8(carry, block)
|
||||
emit(text)
|
||||
}
|
||||
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
|
||||
})
|
||||
if genErr != nil {
|
||||
return genErr
|
||||
}
|
||||
if carry != "" {
|
||||
// The stream ended mid-sequence: the held-back bytes can no longer
|
||||
// complete, so flush them through the U+FFFD last resort.
|
||||
emit(sanitizeUTF8(carry))
|
||||
}
|
||||
if parse {
|
||||
if reply := replyFromDeltas(parser.Close()); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
|
||||
// or nil when the batch is empty (markers consumed, nothing emitted yet).
|
||||
// Message mirrors the batch's content text so legacy chan-string consumers
|
||||
// see exactly the displayed tokens.
|
||||
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
var content strings.Builder
|
||||
for _, delta := range deltas {
|
||||
content.WriteString(delta.GetContent())
|
||||
}
|
||||
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) signature; the gRPC server prefers
|
||||
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
|
||||
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := d.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string path: rich replies reduced to
|
||||
// their content text. Note the inverted channel ownership - the LEGACY
|
||||
// contract requires the impl to close the channel.
|
||||
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
|
||||
// side prepends bos per the vocab) and decodes the returned id array.
|
||||
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
var out string
|
||||
var tokErr error
|
||||
d.submit(func() {
|
||||
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
|
||||
})
|
||||
if tokErr != nil {
|
||||
return pb.TokenizationResponse{}, tokErr
|
||||
}
|
||||
var tokens []int32
|
||||
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
|
||||
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
|
||||
}
|
||||
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
|
||||
}
|
||||
627
backend/go/dllm/dllm_test.go
Normal file → Executable file
627
backend/go/dllm/dllm_test.go
Normal file → Executable file
@@ -1,13 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
func TestDllm(t *testing.T) {
|
||||
@@ -131,10 +137,59 @@ var _ = Describe("buildOptsJSON", func() {
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("splitValidUTF8", func() {
|
||||
It("holds back a trailing incomplete sequence and completes it next block", func() {
|
||||
emit, carry := splitValidUTF8("", "caf\xe2")
|
||||
Expect(emit).To(Equal("caf"))
|
||||
Expect(carry).To(Equal("\xe2"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\x82")
|
||||
Expect(emit).To(BeEmpty())
|
||||
Expect(carry).To(Equal("\xe2\x82"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\xac!")
|
||||
Expect(emit).To(Equal("€!"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("holds back up to 3 bytes of a 4-byte sequence", func() {
|
||||
emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte
|
||||
Expect(emit).To(Equal("x"))
|
||||
Expect(carry).To(Equal("\xf0\x9f\x98"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\x80")
|
||||
Expect(emit).To(Equal("😀"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("replaces undecodable bytes immediately instead of carrying them", func() {
|
||||
// A mid-string invalid byte can never complete: carrying it would let
|
||||
// the carry grow unboundedly, so it is substituted on the spot.
|
||||
emit, carry := splitValidUTF8("", "a\xe2bc")
|
||||
Expect(emit).To(Equal("a<>bc"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
|
||||
// Orphan continuation bytes at the end have no leading byte to wait
|
||||
// for either.
|
||||
emit, carry = splitValidUTF8("", "a\x82")
|
||||
Expect(emit).To(Equal("a<>"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("passes pure ASCII and complete UTF-8 through untouched", func() {
|
||||
emit, carry := splitValidUTF8("", "héllo €")
|
||||
Expect(emit).To(Equal("héllo €"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("goStringFromCPtr", func() {
|
||||
It("copies a NUL-terminated buffer", func() {
|
||||
buf := []byte("dllm\x00")
|
||||
s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0])))
|
||||
// The uintptr round-trip hides buf from the GC's liveness analysis;
|
||||
// keep it reachable until after the copy.
|
||||
runtime.KeepAlive(buf)
|
||||
Expect(s).To(Equal("dllm"))
|
||||
})
|
||||
|
||||
@@ -142,3 +197,575 @@ var _ = Describe("goStringFromCPtr", func() {
|
||||
Expect(goStringFromCPtr(0)).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Backend wiring (T4): fake-generator specs, no libdllm.so required.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// fakeGenCall records the arguments of one call made to fakeGen.
type fakeGenCall struct {
	// prompt is the prompt passed to generate/generateStream, or the text
	// passed to tokenizeJSON.
	prompt string
	// optsJSON is the opts JSON passed to generate/generateStream; empty for
	// tokenizeJSON calls.
	optsJSON string
}
|
||||
|
||||
// fakeGen implements generator in-process. It records every call (prompt +
// opts JSON), tracks concurrent in-flight calls to prove worker
// serialization, and replays canned output (out for generate/tokenize,
// blocks for generateStream).
type fakeGen struct {
	// mu guards calls, inFlight and maxInFlight.
	mu          sync.Mutex
	calls       []fakeGenCall
	inFlight    int
	maxInFlight int

	out    string        // canned result of generate and tokenizeJSON
	blocks []string      // canned blocks replayed by generateStream
	err    error         // canned error returned by every call
	delay  time.Duration // optional sleep inside generate
}

// begin records the call and bumps the in-flight counter, updating the
// high-water mark.
func (f *fakeGen) begin(prompt, optsJSON string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON})
	f.inFlight++
	if f.inFlight > f.maxInFlight {
		f.maxInFlight = f.inFlight
	}
}

// end marks one in-flight call as finished.
func (f *fakeGen) end() {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.inFlight--
}

// snapshot returns a copy of the recorded calls and the highest number of
// calls observed in flight at once.
func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) {
	f.mu.Lock()
	defer f.mu.Unlock()
	return append([]fakeGenCall(nil), f.calls...), f.maxInFlight
}

// generate records the call, optionally sleeps for delay, and returns the
// canned out/err pair.
func (f *fakeGen) generate(prompt, optsJSON string) (string, error) {
	f.begin(prompt, optsJSON)
	defer f.end()
	if f.delay > 0 {
		time.Sleep(f.delay)
	}
	return f.out, f.err
}

// generateStream records the call and, unless err is set, replays every
// canned block through onBlock in order.
func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
	f.begin(prompt, optsJSON)
	defer f.end()
	if f.err != nil {
		return f.err
	}
	for _, b := range f.blocks {
		onBlock(b)
	}
	return nil
}

// tokenizeJSON records the call (with empty opts JSON) and returns the
// canned out/err pair.
func (f *fakeGen) tokenizeJSON(text string) (string, error) {
	f.begin(text, "")
	defer f.end()
	return f.out, f.err
}

// cancel and free are no-ops: the fake owns no C resources.
func (f *fakeGen) cancel() {}
func (f *fakeGen) free()   {}
|
||||
|
||||
// newTestDllm assembles a backend around a fake generator (bypassing Load,
|
||||
// which needs libdllm.so) and registers cleanup of the worker goroutine.
|
||||
func newTestDllm(g generator, genOpts map[string]any) *Dllm {
|
||||
d := &Dllm{gen: g, genOpts: genOpts}
|
||||
d.startWorker()
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
return d
|
||||
}
|
||||
|
||||
// drainReplies empties ch without blocking, failing the spec if the channel
|
||||
// was closed (PredictStreamRich must NOT close it - interface.go contract).
|
||||
// Size ch above the expected reply count: an overflow deadlocks the spec on
|
||||
// the producer's send instead of failing it.
|
||||
func drainReplies(ch chan *pb.Reply) []*pb.Reply {
|
||||
var out []*pb.Reply
|
||||
for {
|
||||
select {
|
||||
case r, ok := <-ch:
|
||||
if !ok {
|
||||
Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)")
|
||||
}
|
||||
expectValidUTF8Reply(r)
|
||||
out = append(out, r)
|
||||
default:
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expectValidUTF8Reply is the blanket guard for the proto3 marshaling
|
||||
// contract: grpc-go rejects any string field carrying invalid UTF-8, so every
|
||||
// reply field that ends up in a proto string must validate.
|
||||
func expectValidUTF8Reply(r *pb.Reply) {
|
||||
GinkgoHelper()
|
||||
Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8")
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8")
|
||||
Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8")
|
||||
for _, tc := range delta.GetToolCalls() {
|
||||
Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8")
|
||||
Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("Dllm backend wiring", func() {
|
||||
// PredictRich specs: prompt rendering and output parsing under
// use_tokenizer_template, raw-prompt pass-through, UTF-8 sanitizing, opts
// JSON mapping, error surfacing, and the Free/worker concurrency behavior.
Describe("PredictRich", func() {
	It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() {
		fake := &fakeGen{out: "<|channel>thought\npondering<channel|>The answer.<turn|>"}
		d := newTestDllm(fake, nil)

		reply, err := d.PredictRich(&pb.PredictOptions{
			UseTokenizerTemplate: true,
			Messages:             []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
			Metadata:             map[string]string{"enable_thinking": "true"},
		})
		Expect(err).ToNot(HaveOccurred())

		calls, _ := fake.snapshot()
		Expect(calls).To(HaveLen(1))
		// The enable_thinking=true render from the transformers fixture.
		Expect(calls[0].prompt).To(Equal(
			"<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n"))

		Expect(string(reply.GetMessage())).To(Equal("The answer."))
		Expect(reply.GetChatDeltas()).To(HaveLen(2))
		Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering"))
		Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer."))
	})

	It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() {
		fake := &fakeGen{out: "hi"}
		d := newTestDllm(fake, nil)

		_, err := d.PredictRich(&pb.PredictOptions{
			UseTokenizerTemplate: true,
			Messages:             []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
		})
		Expect(err).ToNot(HaveOccurred())

		calls, _ := fake.snapshot()
		// No-thinking render: the template pre-opens AND pre-closes an
		// empty thought channel in the generation prompt.
		Expect(calls[0].prompt).To(Equal(
			"<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"))
	})

	It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() {
		// Marker-looking text must survive untouched: in raw-prompt mode
		// the caller templates themselves and the Go-side extraction
		// applies, so the backend must not interpret the output.
		fake := &fakeGen{out: "<|channel>thought\nnot parsed<channel|>tail"}
		d := newTestDllm(fake, nil)

		reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"})
		Expect(err).ToNot(HaveOccurred())

		calls, _ := fake.snapshot()
		Expect(calls[0].prompt).To(Equal("my raw prompt"))
		Expect(string(reply.GetMessage())).To(Equal(fake.out))
		Expect(reply.GetChatDeltas()).To(HaveLen(1))
		Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out))
	})

	It("sanitizes invalid UTF-8 in the non-streaming output", func() {
		// Byte-fallback tokens can decode to lone malformed bytes; the
		// whole-output sanitize must replace them so proto3 marshaling of
		// Message/ChatDeltas cannot fail.
		fake := &fakeGen{out: "a\xe2b"}
		d := newTestDllm(fake, nil)

		reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
		Expect(err).ToNot(HaveOccurred())
		expectValidUTF8Reply(reply)
		// The lone 0xE2 becomes the replacement character. It is written as
		// an escape so the expectation survives editors and tools that
		// mangle a literal U+FFFD.
		Expect(string(reply.GetMessage())).To(Equal("a\uFFFDb"))
		Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a\uFFFDb"))
	})

	It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() {
		fake := &fakeGen{out: "x"}
		d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"})

		_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7})
		Expect(err).ToNot(HaveOccurred())

		calls, _ := fake.snapshot()
		Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
	})

	It("omits n_predict and seed when unset so the engine defaults apply", func() {
		fake := &fakeGen{out: "x"}
		d := newTestDllm(fake, nil)

		_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
		Expect(err).ToNot(HaveOccurred())

		calls, _ := fake.snapshot()
		Expect(calls[0].optsJSON).To(MatchJSON(`{}`))
	})

	It("surfaces generator errors", func() {
		fake := &fakeGen{err: errors.New("boom")}
		d := newTestDllm(fake, nil)

		_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
		Expect(err).To(MatchError("boom"))
	})

	It("errors before generating when no model is loaded", func() {
		d := &Dllm{} // no Load, no worker: must fail fast, not hang
		_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
		Expect(err).To(HaveOccurred())
	})

	It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() {
		// server.go's Free has no locking of its own: the backend's genMu
		// must hold Free back until the racing generate drains, instead of
		// closing the jobs channel (panic) or freeing the C ctx under it.
		fake := &fakeGen{out: "x", delay: 50 * time.Millisecond}
		d := newTestDllm(fake, nil)

		predictDone := make(chan error, 1)
		go func() {
			defer GinkgoRecover()
			_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
			predictDone <- err
		}()
		// Wait until the fake generate is actually in flight (the read
		// lock is held from before submit until PredictRich returns).
		Eventually(func() int {
			_, maxInFlight := fake.snapshot()
			return maxInFlight
		}).Should(Equal(1))

		Expect(d.Free()).To(Succeed())
		// Free's write lock means the request finished before Free did.
		var predictErr error
		Eventually(predictDone).Should(Receive(&predictErr))
		Expect(predictErr).ToNot(HaveOccurred())
	})

	It("returns model-not-loaded for requests after Free", func() {
		fake := &fakeGen{out: "x"}
		d := newTestDllm(fake, nil)
		Expect(d.Free()).To(Succeed())

		_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
		Expect(err).To(MatchError(ContainSubstring("model not loaded")))
	})

	It("serializes concurrent requests through the worker goroutine", func() {
		// dllm_capi.h: one ctx = one concurrent generate. Two overlapping
		// PredictRich calls must execute the C calls one at a time.
		fake := &fakeGen{out: "x", delay: 30 * time.Millisecond}
		d := newTestDllm(fake, nil)

		var wg sync.WaitGroup
		for range 2 {
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer GinkgoRecover()
				_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
				Expect(err).ToNot(HaveOccurred())
			}()
		}
		wg.Wait()

		calls, maxInFlight := fake.snapshot()
		Expect(calls).To(HaveLen(2))
		Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue")
	})
})
|
||||
|
||||
// PredictStreamRich specs: per-block reply emission, raw pass-through, parser
// holdback flushing, multi-byte reassembly across block boundaries, and the
// error / not-loaded paths. drainReplies enforces that the results channel is
// never closed by the backend.
Describe("PredictStreamRich", func() {
	It("emits one reply per delta-producing block and leaves the channel open", func() {
		// Blocks split mid-marker and mid-payload: the parser's holdback
		// must keep marker fragments out of the emitted deltas.
		fake := &fakeGen{blocks: []string{
			"<|channel>thou",        // partial channel open: no deltas yet
			"ght\nponder",           // header completes, reasoning starts
			"ing<channel|>Hi ",      // reasoning ends, content starts
			"there<turn|>discarded", // turn ends: trailing text dropped
		}}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{
			UseTokenizerTemplate: true,
			Messages:             []*pb.Message{{Role: "user", Content: "hi"}},
		}, ch)
		Expect(err).ToNot(HaveOccurred())

		replies := drainReplies(ch)
		Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply")

		var content, reasoning string
		for _, r := range replies {
			for _, delta := range r.GetChatDeltas() {
				content += delta.GetContent()
				reasoning += delta.GetReasoningContent()
			}
		}
		Expect(reasoning).To(Equal("pondering"))
		Expect(content).To(Equal("Hi there"))
		// Message mirrors each reply's content so legacy consumers see
		// exactly the displayed tokens.
		Expect(string(replies[1].GetMessage())).To(Equal("Hi "))
		Expect(string(replies[2].GetMessage())).To(Equal("there"))
	})

	It("streams raw blocks verbatim without use_tokenizer_template", func() {
		fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
		Expect(err).ToNot(HaveOccurred())

		replies := drainReplies(ch)
		Expect(replies).To(HaveLen(2), "empty blocks produce no reply")
		Expect(string(replies[0].GetMessage())).To(Equal("abc"))
		Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def"))
		Expect(replies[1].GetChatDeltas()).To(HaveLen(1))
	})

	It("flushes parser holdback after the stream ends", func() {
		// The unterminated partial marker "<chan" is held back during the
		// stream and must come out as content on the final flush.
		fake := &fakeGen{blocks: []string{"tail<chan"}}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{
			UseTokenizerTemplate: true,
			Messages:             []*pb.Message{{Role: "user", Content: "hi"}},
		}, ch)
		Expect(err).ToNot(HaveOccurred())

		var content string
		for _, r := range drainReplies(ch) {
			content += string(r.GetMessage())
		}
		Expect(content).To(Equal("tail<chan"))
	})

	It("reassembles a multi-byte character split across block boundaries", func() {
		// Per-block detokenize can split "€" (E2 82 AC) as E2 | 82 AC.
		// Emitting the lone E2 would make grpc-go fail the marshal of the
		// whole reply; the trailing incomplete sequence must be held back
		// and completed by the next block.
		fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac ok"}}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
		Expect(err).ToNot(HaveOccurred())

		var content string
		for _, r := range drainReplies(ch) { // drain asserts ValidString per reply
			content += string(r.GetMessage())
		}
		Expect(content).To(Equal("caf€ ok"))
	})

	It("reassembles a split multi-byte character in parsed (gemma4) mode too", func() {
		fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac<turn|>"}}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{
			UseTokenizerTemplate: true,
			Messages:             []*pb.Message{{Role: "user", Content: "hi"}},
		}, ch)
		Expect(err).ToNot(HaveOccurred())

		var content string
		for _, r := range drainReplies(ch) {
			for _, delta := range r.GetChatDeltas() {
				content += delta.GetContent()
			}
		}
		Expect(content).To(Equal("caf€"))
	})

	It("replaces an incomplete sequence left at stream end with U+FFFD", func() {
		// A byte-fallback token can leave a lone leading byte (0xE2) that
		// no later block completes: the final flush must substitute it,
		// never emit it raw and never drop into a marshal error.
		fake := &fakeGen{blocks: []string{"ok\xe2"}}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
		Expect(err).ToNot(HaveOccurred())

		var content string
		for _, r := range drainReplies(ch) {
			content += string(r.GetMessage())
		}
		// Written as an escape so the expected replacement character
		// survives tools that mangle a literal U+FFFD.
		Expect(content).To(Equal("ok\uFFFD"))
	})

	It("surfaces generator errors without sending replies", func() {
		fake := &fakeGen{err: errors.New("stream boom")}
		d := newTestDllm(fake, nil)

		ch := make(chan *pb.Reply, 16)
		err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
		Expect(err).To(MatchError("stream boom"))
		Expect(drainReplies(ch)).To(BeEmpty())
	})

	It("errors before generating when no model is loaded", func() {
		d := &Dllm{} // no Load, no worker: must fail fast, not hang
		ch := make(chan *pb.Reply, 1)
		err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
		Expect(err).To(MatchError(ContainSubstring("model not loaded")))
		Expect(drainReplies(ch)).To(BeEmpty())
	})
})
|
||||
|
||||
// Specs for the legacy string-based entry points.
Describe("legacy Predict/PredictStream adapters", func() {
	It("Predict returns the aggregated content string", func() {
		backend := newTestDllm(&fakeGen{out: "plain text"}, nil)

		text, err := backend.Predict(&pb.PredictOptions{Prompt: "p"})

		Expect(err).ToNot(HaveOccurred())
		Expect(text).To(Equal("plain text"))
	})

	It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() {
		backend := newTestDllm(&fakeGen{blocks: []string{"a", "b"}}, nil)
		results := make(chan string, 16)

		Expect(backend.PredictStream(&pb.PredictOptions{Prompt: "p"}, results)).To(Succeed())

		// Ranging only terminates if the implementation closed the channel.
		var forwarded []string
		for chunk := range results {
			forwarded = append(forwarded, chunk)
		}
		Expect(forwarded).To(Equal([]string{"a", "b"}))
	})
})
|
||||
|
||||
// Specs for TokenizeString: decoding the generator's JSON id array, rejecting
// malformed output, and failing fast without a loaded model.
Describe("TokenizeString", func() {
	It("decodes the C-side JSON id array", func() {
		gen := &fakeGen{out: "[2,18]"}
		backend := newTestDllm(gen, nil)

		tokenized, err := backend.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
		Expect(err).ToNot(HaveOccurred())
		Expect(tokenized.Length).To(Equal(int32(2)))
		Expect(tokenized.Tokens).To(Equal([]int32{2, 18}))

		recorded, _ := gen.snapshot()
		Expect(recorded[0].prompt).To(Equal("hello"))
	})

	It("fails loud on a malformed id array", func() {
		backend := newTestDllm(&fakeGen{out: "not json"}, nil)

		_, err := backend.TokenizeString(&pb.PredictOptions{Prompt: "hello"})

		Expect(err).To(HaveOccurred())
	})

	It("errors before tokenizing when no model is loaded", func() {
		// No Load and no worker: the call must fail fast rather than hang.
		unloaded := &Dllm{}

		_, err := unloaded.TokenizeString(&pb.PredictOptions{Prompt: "hello"})

		Expect(err).To(MatchError(ContainSubstring("model not loaded")))
	})
})
|
||||
|
||||
// Specs for parseModelGenOpts: which "key:value" entries are kept, how their
// values are typed, and that the result serializes through buildOptsJSON.
Describe("parseModelGenOpts", func() {
	It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() {
		entries := []string{
			"eb_max_steps:16",
			"eb_t_min:0.25",
			"kv_cache:auto",
			"blocks:4",
			"unrelated_key:1", // other layers' options: skipped
			"malformed",       // no colon: skipped
		}
		want := map[string]any{
			"eb_max_steps": int64(16),
			"eb_t_min":     0.25,
			"kv_cache":     "auto",
			"blocks":       int64(4),
		}

		Expect(parseModelGenOpts(entries)).To(Equal(want))
	})

	It("round-trips through buildOptsJSON (only flat scalars are produced)", func() {
		parsed := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"})

		encoded, err := buildOptsJSON(parsed)

		Expect(err).ToNot(HaveOccurred())
		Expect(encoded).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`))
	})
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var _ = Describe("Dllm backend (real tiny model)", func() {
|
||||
BeforeEach(func() {
|
||||
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
|
||||
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
Expect(libLoadErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() {
|
||||
d := &Dllm{}
|
||||
Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed())
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
|
||||
// TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18].
|
||||
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
|
||||
Expect(resp.Length).To(Equal(int32(2)))
|
||||
|
||||
req := &pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
Tokens: 16,
|
||||
Seed: 7,
|
||||
}
|
||||
|
||||
// Non-streaming: the tiny random-weight model emits arbitrary vocab
|
||||
// words; with no gemma4 markers in them everything is content.
|
||||
reply, err := d.PredictRich(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(reply.GetMessage())).ToNot(BeEmpty())
|
||||
Expect(reply.GetChatDeltas()).ToNot(BeEmpty())
|
||||
|
||||
// Streaming: at least one reply, and the channel-ownership rule is
|
||||
// honored (drainReplies fails the spec on a closed channel).
|
||||
ch := make(chan *pb.Reply, 64)
|
||||
Expect(d.PredictStreamRich(req, ch)).To(Succeed())
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).ToNot(BeEmpty())
|
||||
var streamed string
|
||||
for _, r := range replies {
|
||||
streamed += string(r.GetMessage())
|
||||
}
|
||||
Expect(streamed).ToNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
0
backend/go/dllm/main.go
Normal file → Executable file
0
backend/go/dllm/main.go
Normal file → Executable file
Reference in New Issue
Block a user