diff --git a/backend/go/dllm/gemma4_parser.go b/backend/go/dllm/gemma4_parser.go new file mode 100755 index 000000000..cf381cea1 --- /dev/null +++ b/backend/go/dllm/gemma4_parser.go @@ -0,0 +1,562 @@ +// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in +// arbitrary fragments (per committed diffusion block; a fragment can split +// anywhere, including mid-marker and mid-payload), is turned into +// pb.ChatDelta events (content / reasoning_content / tool_calls). +// +// Normative sources: +// - The chat template embedded at the top of gemma4_renderer.go ("tpl L" +// citations below refer to its numbered lines). The OUTPUT format mirrors +// what the template renders for assistant history: thought channels +// (<|channel>thought\n ... , tpl L240), tool calls +// (<|tool_call>call:name{...}, tpl L246-L257) and turn ends +// (, tpl L351). +// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker +// handling, the call:name{...} argument grammar and its decoder, ported +// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers, +// the "thought\n" role label, is_reasoning_end semantics). +// +// Initial state (derived from the generation prompt, tpl L356-L362, see +// RenderGemma4): +// - enable_thinking=false: the prompt ends with "<|turn>model\n" + +// "<|channel>thought\n" - an EMPTY thought channel, pre-opened +// AND pre-closed by the template. The model's output therefore starts in +// plain content. Use NewGemma4Parser(false). +// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model +// opens and closes its own thought channel in the OUTPUT +// ("<|channel>thought\n...reasoning...final answer", per the +// vLLM Gemma4ReasoningParser docstring). The parser still starts in +// content state - the channel markers in the output drive the switch. +// Use NewGemma4Parser(false) here too. 
+// - NewGemma4Parser(true) is for callers that pre-open the thought channel +// in the prompt themselves (appending "<|channel>thought\n" after the +// generation prompt to force thinking): the output then begins mid-thought +// and everything is reasoning until the first . +// +// State diagram (markers are consumed, never emitted): +// +// <|channel> \n (channel name dropped: the +// [content] --------------> [chan-header] ----> [thought] "thought\n" role +// ^ | (stray close: swallowed, label, stripped +// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does) +// ^ +// +----------------------------------------- [thought] +// ^ | <|tool_call> (implicit +// +-------------- [tool-call] <-------------------+ reasoning end, vLLM +// | <|tool_call> ^ is_reasoning_end) +// +-------------------+ +// [content]/[thought] --- ---> [done] (everything after is dropped) +// +// Buffering rules: +// - content/thought states hold back at most len(longest marker)-1 bytes: +// the longest tail that is still a proper prefix of a watched marker. +// Content is otherwise emitted immediately (no unbounded buffering). +// - the tool-call state buffers the whole payload until . This +// is unbounded in principle but bounded in practice by the model's +// diffusion canvas, and is required because the call:name{...} payload +// only becomes decodable (and trustworthy) once complete - the same +// reason vLLM's parser accumulates before parsing. +// - Close() flushes whatever is still held: partial markers come out as +// content/reasoning (per the state that held them); an unterminated +// channel header or tool-call payload is re-emitted RAW (including its +// opening marker) as content - malformed output is never silently +// dropped (mirrors vLLM extract_tool_calls returning the raw text as +// content when its regex does not match). 
+// +// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial +// payload on every token and streams argument-JSON diffs (its `partial=True` +// decoder mode plus withholding logic exist only for that). Our fragments are +// whole committed diffusion blocks, so each completed tool call is emitted +// once, as a single ToolCallDelta carrying index + id + name + the full +// arguments JSON - exactly the shape backend/python/vllm/backend.py emits +// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates. +package main + +import ( + "encoding/json" + "regexp" + "strconv" + "strings" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// gemma4CallRE is vLLM's tool_call_regex +// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}`, DOTALL) anchored to +// a single already-extracted payload: name charset [\w\-.], braces mandatory. +var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`) + +type g4State int + +const ( + g4Content g4State = iota + g4ChanHeader + g4Thought + g4ToolCall + g4Done +) + +// Markers watched per emitting state. A stray outside a tool +// call is deliberately NOT watched: it passes through verbatim, consistent +// with the malformed-payload fallback re-emitting it as content. +var ( + gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd} + gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd} +) + +type Gemma4Parser struct { + state g4State + // held is the per-state carry-over between Feed calls: a partial marker + // (content/thought), a partial channel header (chan-header) or the + // payload accumulated so far (tool-call). + held string + toolIdx int +} + +// NewGemma4Parser returns a parser positioned per the initial-state rules in +// the header comment: startInThought=true only when the caller pre-opened a +// thought channel in the prompt. 
+func NewGemma4Parser(startInThought bool) *Gemma4Parser { + state := g4Content + if startInThought { + state = g4Thought + } + return &Gemma4Parser{state: state} +} + +// Feed consumes the next output fragment and returns the deltas it completes. +func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta { + if text == "" || p.state == g4Done { + return nil + } + pending := p.held + text + p.held = "" + var em g4Emitter + for pending != "" { + switch p.state { + case g4Content, g4Thought: + markers := gemma4ContentMarkers + if p.state == g4Thought { + markers = gemma4ThoughtMarkers + } + idx, marker := findEarliestGemma4Marker(pending, markers) + if idx == -1 { + hold := gemma4MarkerHoldback(pending, markers) + p.emitText(&em, pending[:len(pending)-hold]) + p.held = pending[len(pending)-hold:] + pending = "" + continue + } + p.emitText(&em, pending[:idx]) + pending = pending[idx+len(marker):] + switch marker { + case gemma4ChannelOpen: + p.state = g4ChanHeader + case gemma4ChannelClose: + // In thought: channel ends. In content: stray close, + // swallowed (strip_thinking keeps both sides, tpl L148-L158). + p.state = g4Content + case gemma4ToolCallOpen: + p.state = g4ToolCall + case gemma4TurnEnd: + p.state = g4Done + } + case g4ChanHeader: + // The channel header is "\n"; the template only ever writes + // "thought" (tpl L240/L360) and the label is structural, so it is + // dropped, not emitted (vLLM strips the same "thought\n" prefix). + nl := strings.IndexByte(pending, '\n') + if nl == -1 { + p.held = pending + pending = "" + continue + } + pending = pending[nl+1:] + p.state = g4Thought + case g4ToolCall: + end := strings.Index(pending, gemma4ToolCallClose) + if end == -1 { + p.held = pending + pending = "" + continue + } + p.emitToolCall(&em, pending[:end]) + pending = pending[end+len(gemma4ToolCallClose):] + p.state = g4Content + case g4Done: + pending = "" + } + } + return em.deltas +} + +// Close flushes held-back partials. 
Incomplete structures (open channel +// header, unterminated tool payload) are re-emitted raw as content rather +// than dropped. The parser is finished afterwards. +func (p *Gemma4Parser) Close() []*pb.ChatDelta { + var em g4Emitter + switch p.state { + case g4Content: + em.content(p.held) + case g4Thought: + em.reasoning(p.held) + case g4ChanHeader: + em.content(gemma4ChannelOpen + p.held) + case g4ToolCall: + em.content(gemma4ToolCallOpen + p.held) + case g4Done: + } + p.held = "" + p.state = g4Done + return em.deltas +} + +func (p *Gemma4Parser) emitText(em *g4Emitter, s string) { + if p.state == g4Thought { + em.reasoning(s) + return + } + em.content(s) +} + +// emitToolCall decodes one complete <|tool_call>... payload. On a +// payload that does not match call:name{...} the raw text (markers included) +// is emitted as content, mirroring vLLM's extract_tool_calls fallback. +func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) { + m := gemma4CallRE.FindStringSubmatch(payload) + if m == nil { + em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose) + return + } + // Index-based ids: deterministic (the split-invariance property relies + // on it) and matching the call_ convention of pkg/grpc/rich_test.go; + // core only needs ids to be non-empty and unique within the response. + em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0)) + p.toolIdx++ +} + +// g4Emitter collects ChatDeltas; empty text events are dropped. 
+type g4Emitter struct { + deltas []*pb.ChatDelta +} + +func (e *g4Emitter) content(s string) { + if s != "" { + e.deltas = append(e.deltas, &pb.ChatDelta{Content: s}) + } +} + +func (e *g4Emitter) reasoning(s string) { + if s != "" { + e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s}) + } +} + +func (e *g4Emitter) tool(index int, id, name, argsJSON string) { + e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{ + Index: int32(index), + Id: id, + Name: name, + Arguments: argsJSON, + }}}) +} + +// findEarliestGemma4Marker returns the position and value of the first +// complete marker occurrence, or (-1, ""). +func findEarliestGemma4Marker(s string, markers []string) (int, string) { + best, bestMarker := -1, "" + for _, m := range markers { + if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) { + best, bestMarker = idx, m + } + } + return best, bestMarker +} + +// gemma4MarkerHoldback returns the length of the longest suffix of s that is +// a proper prefix of a watched marker - the only bytes that may still grow +// into a marker and therefore must not be emitted yet (bounded by the +// longest marker, so content is never buffered unboundedly). 
// A tail that equals a whole marker is not a proper prefix of that marker and
// is never counted against it. An empty marker list yields 0.
func gemma4MarkerHoldback(s string, markers []string) int {
	longest := 0
	for _, m := range markers {
		if len(m) > longest {
			longest = len(m)
		}
	}
	// A proper prefix is at most one byte shorter than the longest marker,
	// and can never be longer than s itself.
	limit := longest - 1
	if limit > len(s) {
		limit = len(s)
	}
	// Longest candidate first, so the first hit is the answer.
	for k := limit; k >= 1; k-- {
		tail := s[len(s)-k:]
		for _, m := range markers {
			// k < len(m) enforces "proper": without it a short marker that
			// is fully present at the end of s would count as a partial one.
			if k < len(m) && strings.HasPrefix(m, tail) {
				return k
			}
		}
	}
	return 0
}

// ---------------------------------------------------------------------------
// call:name{...} argument decoder
//
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
// raw content at Close), so vLLM's partial-withholding machinery
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
//
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
//
//	args   := pair (',' pair)*
//	pair   := key ':' value        (keys unquoted, up to the first ':')
//	value  := string | object | array | bare
//	string := '<|"|>' ... '<|"|>'  (no escapes; unterminated -> rest)
//	object := '{' args '}'         (delimited strings skipped when
//	array  := '[' value,* ']'       counting braces/brackets)
//	bare   := true | false | null/none/nil | number | bare-string
//
// Output is a JSON object/array string with keys in payload order (Python
// dict insertion order), built with HTML escaping off so payload text
// survives byte-for-byte.
// ---------------------------------------------------------------------------

// isGemma4Space reports whether c is one of the three whitespace bytes the
// decoder skips: space, newline or tab.
func isGemma4Space(c byte) bool {
	switch c {
	case ' ', '\n', '\t':
		return true
	}
	return false
}

// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
// overflow is a fatal process kill, not a recoverable error, so past the cap
// a nested body gracefully degrades to a JSON string of its raw text.
+const gemma4MaxArgsDepth = 100 + +// decodeGemma4Args decodes one args body (the text between the outer braces +// of call:name{...}) into a JSON object string. depth is the current nesting +// level (0 at the payload root); see gemma4MaxArgsDepth. +func decodeGemma4Args(s string, depth int) string { + if depth > gemma4MaxArgsDepth { + return gemma4JSONString(s) + } + var b strings.Builder + b.WriteString("{") + first := true + pair := func(key, val string) { + if !first { + b.WriteString(",") + } + first = false + b.WriteString(gemma4JSONString(key)) + b.WriteString(":") + b.WriteString(val) + } + i, n := 0, len(s) + for i < n { + for i < n && (isGemma4Space(s[i]) || s[i] == ',') { + i++ + } + if i >= n { + break + } + keyStart := i + for i < n && s[i] != ':' { + i++ + } + if i >= n { + break // no ':' -> trailing junk, dropped (vLLM does the same) + } + key := strings.TrimSpace(s[keyStart:i]) + i++ // skip ':' + for i < n && isGemma4Space(s[i]) { + i++ + } + if i >= n { + pair(key, `""`) // "key:" with nothing after -> empty string + break + } + switch { + case strings.HasPrefix(s[i:], gemma4StringDelim): + i += len(gemma4StringDelim) + if end := strings.Index(s[i:], gemma4StringDelim); end == -1 { + pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest + i = n + } else { + pair(key, gemma4JSONString(s[i:i+end])) + i += end + len(gemma4StringDelim) + } + case s[i] == '{': + inner, next := scanGemma4Balanced(s, i, '{', '}') + pair(key, decodeGemma4Args(inner, depth+1)) + i = next + case s[i] == '[': + inner, next := scanGemma4Balanced(s, i, '[', ']') + pair(key, decodeGemma4Array(inner, depth+1)) + i = next + default: + valStart := i + for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' { + i++ + } + if i == valStart { + // No progress (value starts on a stray '}'/']'): abort on + // malformed input rather than loop, like vLLM. 
+ i = n + continue + } + pair(key, decodeGemma4Bare(s[valStart:i])) + } + } + b.WriteString("}") + return b.String() +} + +// decodeGemma4Array decodes one array body (the text between '[' and ']') +// into a JSON array string. depth is the current nesting level; see +// gemma4MaxArgsDepth. +func decodeGemma4Array(s string, depth int) string { + if depth > gemma4MaxArgsDepth { + return gemma4JSONString(s) + } + var b strings.Builder + b.WriteString("[") + first := true + item := func(val string) { + if !first { + b.WriteString(",") + } + first = false + b.WriteString(val) + } + i, n := 0, len(s) + for i < n { + for i < n && (isGemma4Space(s[i]) || s[i] == ',') { + i++ + } + if i >= n { + break + } + switch { + case strings.HasPrefix(s[i:], gemma4StringDelim): + i += len(gemma4StringDelim) + if end := strings.Index(s[i:], gemma4StringDelim); end == -1 { + item(gemma4JSONString(s[i:])) + i = n + } else { + item(gemma4JSONString(s[i : i+end])) + i += end + len(gemma4StringDelim) + } + case s[i] == '{': + inner, next := scanGemma4Balanced(s, i, '{', '}') + item(decodeGemma4Args(inner, depth+1)) + i = next + case s[i] == '[': + inner, next := scanGemma4Balanced(s, i, '[', ']') + item(decodeGemma4Array(inner, depth+1)) + i = next + default: + valStart := i + for i < n && s[i] != ',' && s[i] != ']' { + i++ + } + if i == valStart { + i = n // no progress: abort on malformed input, like vLLM + continue + } + item(decodeGemma4Bare(s[valStart:i])) + } + } + b.WriteString("]") + return b.String() +} + +// scanGemma4Balanced scans a brace/bracket-balanced span starting at the +// opener s[start], skipping over <|"|>-delimited strings so structural +// characters inside them do not count (vLLM's depth scan). Returns the inner +// text and the index just past the closer; an unterminated span yields the +// rest of the string (the inner decoder still extracts what is there - this +// path is only reachable from genuinely malformed complete payloads). 
+func scanGemma4Balanced(s string, start int, open, close byte) (string, int) { + depth := 1 + i := start + 1 + innerStart := i + n := len(s) + for i < n && depth > 0 { + if strings.HasPrefix(s[i:], gemma4StringDelim) { + i += len(gemma4StringDelim) + if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 { + i = n + } else { + i += nd + len(gemma4StringDelim) + } + continue + } + switch s[i] { + case open: + depth++ + case close: + depth-- + } + i++ + } + if depth > 0 { + return s[innerStart:], n + } + return s[innerStart : i-1], i +} + +// decodeGemma4Bare maps an undelimited value to its JSON form: booleans, +// null aliases (null/none/nil, case-insensitive - the renderer writes +// Python None as "None", tpl L144-L145 via format_argument's else branch), +// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that +// fails parses as a bare string). +func decodeGemma4Bare(raw string) string { + v := strings.TrimSpace(raw) + if v == "" { + return `""` + } + if v == "true" || v == "false" { + return v + } + switch strings.ToLower(v) { + case "null", "none", "nil": + return "null" + } + if strings.Contains(v, ".") { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return formatGemma4Float(f) + } + } else if iv, err := strconv.ParseInt(v, 10, 64); err == nil { + return strconv.FormatInt(iv, 10) + } + return gemma4JSONString(v) +} + +// formatGemma4Float renders like Python's json.dumps(float): integral floats +// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments +// JSON matches what vLLM would have produced for the same payload. +func formatGemma4Float(f float64) string { + s := strconv.FormatFloat(f, 'g', -1, 64) + if !strings.ContainsAny(s, ".eE") { + s += ".0" + } + return s +} + +// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal +// would escape the angle brackets in "
" to \u003c / \u003e sequences; +// payload text should survive +// byte-for-byte, like Python's json.dumps(ensure_ascii=False)). +func gemma4JSONString(s string) string { + var sb strings.Builder + enc := json.NewEncoder(&sb) + enc.SetEscapeHTML(false) + if err := enc.Encode(s); err != nil { + // Unreachable for plain strings; fall back to default escaping + // rather than emitting invalid JSON. + b, mErr := json.Marshal(s) + if mErr != nil { + return `""` + } + return string(b) + } + // Encode appends a trailing newline. + return strings.TrimSuffix(sb.String(), "\n") +} diff --git a/backend/go/dllm/gemma4_parser_test.go b/backend/go/dllm/gemma4_parser_test.go new file mode 100755 index 000000000..f3c243c02 --- /dev/null +++ b/backend/go/dllm/gemma4_parser_test.go @@ -0,0 +1,592 @@ +package main + +// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events). +// +// Fixture provenance: +// - Entries marked "vLLM: " are direct ports of the named test from +// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the +// authoritative test-suite for the gemma4 tool-call wire format). The +// streaming tests' chunk lists are reused verbatim as Feed fragments. +// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array +// classes from the same file (non-partial mode only; this parser never +// decodes partial payloads, see the divergence note in gemma4_parser.go). +// - Channel/turn-marker expectations come from the chat template embedded +// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158 +// strip_thinking) and vLLM's Gemma4ReasoningParser +// (vllm/reasoning/gemma4_reasoning_parser.py). + +import ( + "encoding/json" + "fmt" + "strings" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core +// folds ToolCallDelta streams (pkg/functions/chat_deltas.go +// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments +// concatenate per index). Tests flatten through the same rules so they +// assert exactly what core will reconstruct. +type flatGemma4Tool struct { + id string + name string + args string +} + +func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) { + var content, reasoning strings.Builder + byIndex := map[int32]*flatGemma4Tool{} + maxIdx := int32(-1) + for _, d := range deltas { + content.WriteString(d.GetContent()) + reasoning.WriteString(d.GetReasoningContent()) + for _, tc := range d.GetToolCalls() { + acc, ok := byIndex[tc.GetIndex()] + if !ok { + acc = &flatGemma4Tool{} + byIndex[tc.GetIndex()] = acc + } + if tc.GetName() != "" { + acc.name = tc.GetName() + } + if tc.GetId() != "" { + acc.id = tc.GetId() + } + acc.args += tc.GetArguments() + if tc.GetIndex() > maxIdx { + maxIdx = tc.GetIndex() + } + } + } + var tools []flatGemma4Tool + for i := int32(0); i <= maxIdx; i++ { + if acc, ok := byIndex[i]; ok { + tools = append(tools, *acc) + } + } + return content.String(), reasoning.String(), tools +} + +type wantGemma4Tool struct { + name string + argsJSON string // compared with MatchJSON (key order irrelevant) +} + +type parseGemma4Case struct { + startInThought bool + fragments []string + wantContent string + wantReasoning string + wantTools []wantGemma4Tool +} + +func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta { + p := NewGemma4Parser(startInThought) + var all []*pb.ChatDelta + for _, f := range fragments { + all = append(all, p.Feed(f)...) + } + return append(all, p.Close()...) 
+} + +var _ = Describe("Gemma4Parser", func() { + DescribeTable("parses streamed gemma4 output into ChatDeltas", + func(c parseGemma4Case) { + content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments)) + Expect(content).To(Equal(c.wantContent)) + Expect(reasoning).To(Equal(c.wantReasoning)) + Expect(tools).To(HaveLen(len(c.wantTools))) + seenIDs := map[string]bool{} + for i, want := range c.wantTools { + Expect(tools[i].name).To(Equal(want.name), "tool %d name", i) + Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i) + Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i) + Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i) + seenIDs[tools[i].id] = true + } + }, + + // --- (1) pure content ------------------------------------------------- + // vLLM: test_no_tool_calls + Entry("pure content, single fragment", parseGemma4Case{ + fragments: []string{"Hello, how can I help you today?"}, + wantContent: "Hello, how can I help you today?", + }), + + // --- (2) thought -> final transition ---------------------------------- + // enable_thinking render: prompt ends at <|turn>model\n and the model + // opens/closes its own thought channel in the OUTPUT (vLLM + // Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n" + // role label after <|channel> is structural and must be stripped + // (vLLM _THOUGHT_PREFIX handling). 
+ Entry("thought channel then final content", parseGemma4Case{ + fragments: []string{"<|channel>thought\nLet me think about this.\nThe answer is 42."}, + wantReasoning: "Let me think about this.\n", + wantContent: "The answer is 42.", + }), + + // --- (3) startInThought both ways ------------------------------------- + Entry("startInThought=true routes initial text to reasoning until ", parseGemma4Case{ + startInThought: true, + fragments: []string{"I am thinking hard.Done."}, + wantReasoning: "I am thinking hard.", + wantContent: "Done.", + }), + // A stray with no open channel is swallowed, matching the + // template's strip_thinking (tpl L148-L158: the marker is dropped, + // text on both sides is kept). + Entry("startInThought=false keeps the same text as content, stray swallowed", parseGemma4Case{ + startInThought: false, + fragments: []string{"I am thinking hard.Done."}, + wantContent: "I am thinking hard.Done.", + }), + + // --- (4) one tool call, full payload type zoo -------------------------- + Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{ + fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}`}, + wantTools: []wantGemma4Tool{{ + name: "complex_function", + argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`, + }}, + }), + + // --- (5) payload split across 3 fragments ------------------------------ + Entry("tool-call payload split across three fragments", parseGemma4Case{ + fragments: []string{ + "<|tool_call>call:get_weather{loc", + `ation:<|"|>Paris, Fra`, + `nce<|"|>}`, + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}}, + }), + + // --- (6) marker split across fragments ---------------------------------- + 
Entry("tool-call open marker split across fragments", parseGemma4Case{ + fragments: []string{ + "<|tool_ca", + `ll>call:get_weather{location:<|"|>London<|"|>}`, + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}}, + }), + Entry("channel open marker split across fragments", parseGemma4Case{ + fragments: []string{ + "<|chan", + "nel>thought\ndeep thoughtfinal", + }, + wantReasoning: "deep thought", + wantContent: "final", + }), + + // --- (7) trailing partial marker held, flushed by Close ----------------- + Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{ + fragments: []string{"Hello <|tool"}, + wantContent: "Hello <|tool", + }), + + // --- (8) malformed/incomplete payload -> content fallback --------------- + // vLLM: test_incomplete_tool_call (no end marker: the whole text stays + // content, never silently dropped). + Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`}, + wantContent: `<|tool_call>call:get_weather{location:<|"|>London`, + }), + Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{ + fragments: []string{"<|tool_call>oops no call syntax done"}, + wantContent: "<|tool_call>oops no call syntax done", + }), + + // --- (9) ends the turn ------------------------------------------- + Entry("text after is ignored, including later fragments", parseGemma4Case{ + fragments: []string{ + "beforeafter", + `more <|tool_call>call:f{}`, + }, + wantContent: "before", + }), + Entry(" inside a thought channel ends the turn", parseGemma4Case{ + startInThought: true, + fragments: []string{"thinkingignored"}, + wantReasoning: "thinking", + }), + + // --- (10) ported vLLM non-streaming cases --------------------------------- + // vLLM: test_single_tool_call + Entry("vLLM: test_single_tool_call", parseGemma4Case{ + fragments: 
[]string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_multiple_arguments + Entry("vLLM: test_multiple_arguments", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}}, + }), + // vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming + // extractor trims the content ("...you."); a streaming parser cannot + // retroactively trim already-emitted text, so the trailing space is + // kept (vLLM's own streaming path keeps it too, see + // test_streaming_text_before_tool_call which only checks a prefix). + Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{ + fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}`}, + wantContent: "Let me check the weather for you. 
", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}}, + }), + // vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence) + Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<|tool_call>call:get_time{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{ + {name: "get_weather", argsJSON: `{"location":"London"}`}, + {name: "get_time", argsJSON: `{"location":"London"}`}, + }, + }), + // vLLM: test_nested_arguments + Entry("vLLM: test_nested_arguments", parseGemma4Case{ + fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}`}, + wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}}, + }), + // vLLM: test_tool_call_with_number_and_boolean + Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{ + fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}`}, + wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}}, + }), + // vLLM: test_hyphenated_function_name + Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{ + fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_dotted_function_name + Entry("vLLM: test_dotted_function_name", parseGemma4Case{ + fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_no_arguments + Entry("vLLM: test_no_arguments", parseGemma4Case{ + fragments: []string{"<|tool_call>call:get_status{}"}, + wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}}, + }), + + // --- ported vLLM streaming cases 
(chunk lists reused as fragments) -------- + // vLLM: test_basic_streaming_single_tool + Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_weather{", + `location:<|"|>Paris`, + ", France", + `<|"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}}, + }), + // vLLM: test_streaming_multi_arg + Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_weather{", + `location:<|"|>Tokyo<|"|>,`, + `unit:<|"|>celsius<|"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}}, + }), + // vLLM: test_streaming_text_before_tool_call + Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{ + fragments: []string{ + "Let me check ", + "the weather. ", + "<|tool_call>", + "call:get_weather{", + `location:<|"|>London<|"|>}`, + "", + }, + wantContent: "Let me check the weather. 
", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}}, + }), + // vLLM: test_streaming_numeric_args + Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:set_config{", + "count:42,", + "active:true}", + "", + }, + wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}}, + }), + // vLLM: test_streaming_boolean_split_across_chunks + Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:search{input:{all:tru", + "e}}", + "", + }, + wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}}, + }), + // vLLM: test_streaming_false_split_across_chunks + Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:set{flag:fals", + "e}", + "", + }, + wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}}, + }), + // vLLM: test_streaming_number_split_across_chunks + Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:set{count:4", + "2}", + "", + }, + wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}}, + }), + // vLLM: test_streaming_empty_args + Entry("vLLM: test_streaming_empty_args", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_status{}", + "", + }, + wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}}, + }), + // vLLM: test_streaming_split_delimiter_no_invalid_json (string + // delimiter <|"|> split across fragments must not leak fragments). 
+ Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:todowrite{", + `content:<|"|>Buy milk<|`, + `"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}}, + }), + // vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call + Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:get_weather{", + `location:<|"|>Paris<|"|>}`, + "<", + "div>", + }, + wantContent: "
", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}}, + }), + // vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes + Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:write_file{", + `path:<|"|>index.html<|"|>,`, + `content:<|"|>` + "\n<", + `html lang="zh-CN">` + "\n<", + "head>\n <", + `meta charset="UTF-8">` + "\n <", + `meta name="viewport" content="width=device-width">` + "\n", + `<|"|>}`, + "", + }, + wantTools: []wantGemma4Tool{{ + name: "write_file", + argsJSON: `{"path":"index.html","content":"\n\n\n \n \n"}`, + }}, + }), + // vLLM: test_streaming_single_chunk_complete_tool_call + Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{ + fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}`}, + wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}}, + }), + // vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete + // calls in ONE fragment; both must come out with distinct indices) + Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{ + fragments: []string{ + `<|tool_call>call:get_weather{location:<|"|>London<|"|>}` + + `<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}`, + }, + wantTools: []wantGemma4Tool{ + {name: "get_weather", argsJSON: `{"location":"London"}`}, + {name: "get_time", argsJSON: `{"timezone":"GMT"}`}, + }, + }), + // vLLM: test_streaming_trailing_bare_bool_not_duplicated + Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{ + fragments: []string{ + "<|tool_call>", + "call:Edit{", + `file_path:<|"|>src/env.py<|"|>,`, + `old_string:<|"|>old_val<|"|>,`, + `new_string:<|"|>new_val<|"|>,`, + "replace_all:", + "false}", + "", + }, + wantTools: []wantGemma4Tool{{ + name: "Edit", + argsJSON: 
`{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`, + }}, + }), + + // --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end: + // a tool_call token means reasoning is over) ----------------------------- + Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{ + startInThought: true, + fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}`}, + wantReasoning: "need the weather", + wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}}, + }), + + // --- (12) empty fragments are no-ops -------------------------------------- + Entry("empty fragments are no-ops", parseGemma4Case{ + fragments: []string{"", "Hello", "", "", " world", ""}, + wantContent: "Hello world", + }), + ) + + It("returns no deltas for an empty fragment and after Close", func() { + p := NewGemma4Parser(false) + Expect(p.Feed("")).To(BeEmpty()) + Expect(p.Feed("hi")).ToNot(BeEmpty()) + Expect(p.Close()).To(BeEmpty()) // nothing held back + // The parser is finished after Close: further input is dropped. + Expect(p.Feed("more")).To(BeEmpty()) + Expect(p.Close()).To(BeEmpty()) + }) + + It("generates index-based tool call ids (call_)", func() { + // Mirrors the index-based id convention of pkg/grpc/rich_test.go and + // keeps ids deterministic for the split-invariance property below. + deltas := parseGemma4Fragments(false, []string{ + `<|tool_call>call:a{}<|tool_call>call:b{}`, + }) + _, _, tools := flattenGemma4Deltas(deltas) + Expect(tools).To(HaveLen(2)) + Expect(tools[0].id).To(Equal("call_0")) + Expect(tools[1].id).To(Equal("call_1")) + }) + + // Property: for a fixed full output, EVERY 2-split position must yield + // exactly the same flattened result as the unsplit parse. This kills + // fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits). 
+ DescribeTable("2-split fragment invariance", + func(startInThought bool, full string) { + refContent, refReasoning, refTools := flattenGemma4Deltas( + parseGemma4Fragments(startInThought, []string{full})) + for i := 0; i <= len(full); i++ { + content, reasoning, tools := flattenGemma4Deltas( + parseGemma4Fragments(startInThought, []string{full[:i], full[i:]})) + Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i)) + Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i)) + Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i)) + } + }, + Entry("thought + content + two tool calls + turn end", false, + "<|channel>thought\nPondering the request...\nSure - calling tools now. "+ + `<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}`+ + `<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}`+ + "Done.ignored tail"), + Entry("startInThought + tool call + trailing partial marker", true, + `Deep thoughtfinal answer <|tool_call>call:noop{} trailing <|tool`), + Entry("malformed payload fallback", false, + `pre <|tool_call>not a call post`), + ) +}) + +// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array +// (non-partial mode; the partial-withholding tests do not apply because this +// parser only ever decodes COMPLETE payloads, see gemma4_parser.go). 
+var _ = Describe("decodeGemma4Args", func() { + DescribeTable("decodes the gemma4 call syntax into JSON arguments", + func(in, wantJSON string) { + Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON)) + }, + // vLLM: test_empty_string / test_whitespace_only + Entry("empty string", "", `{}`), + Entry("whitespace only", " ", `{}`), + // vLLM: test_single_string_value + Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`), + // vLLM: test_string_value_with_comma + Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`), + // vLLM: test_multiple_string_values + Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`), + // vLLM: test_integer_value / test_float_value + Entry("integer value", "count:42", `{"count":42}`), + Entry("float value", "score:3.14", `{"score":3.14}`), + // vLLM: test_boolean_true / test_boolean_false + Entry("boolean true", "flag:true", `{"flag":true}`), + Entry("boolean false", "flag:false", `{"flag":false}`), + // vLLM: test_null_value (bare null must become JSON null, not "null") + Entry("null value", "param:null", `{"param":null}`), + // vLLM: test_mixed_types + Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`, + `{"name":"test","count":42,"active":true,"score":3.14}`), + // vLLM: test_nested_object + Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`), + // vLLM: test_array_of_strings + Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`), + // vLLM: test_unterminated_string (take everything after the delimiter) + Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`), + // vLLM: test_empty_value (key with no value after colon) + Entry("empty value", "key:", `{"key":""}`), + // vLLM: test_trailing_dot_float_partial_withheld, non-partial branch + // (trailing-dot 
floats parse normally outside streaming). + Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`), + ) + + It("terminates and yields valid JSON on malformed input", func() { + // vLLM: test_malformed_partial_array (the assertion there is only + // "returns a dict without hanging"; ours is "valid JSON object"). + out := decodeGemma4Args(":[t:[]", 0) + var v map[string]any + Expect(json.Unmarshal([]byte(out), &v)).To(Succeed()) + }) + + It("degrades nesting beyond the recursion cap to a string value", func() { + // 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual + // recursion would grow the stack with the model's output; a Go stack + // overflow is a fatal process kill, so levels past gemma4MaxArgsDepth + // must gracefully fall back to the raw inner text as a JSON string. + const depth = 200 + body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1) + out := decodeGemma4Args(body, 0) + var v map[string]any + Expect(json.Unmarshal([]byte(out), &v)).To(Succeed()) + levels := 0 + var cur any = v + for { + m, ok := cur.(map[string]any) + if !ok { + break + } + Expect(m).To(HaveKey("a")) + cur = m["a"] + levels++ + } + Expect(levels).To(Equal(gemma4MaxArgsDepth + 1)) + Expect(cur).To(BeAssignableToTypeOf("")) + Expect(cur).To(ContainSubstring("a:{")) + }) +}) + +var _ = Describe("decodeGemma4Array", func() { + DescribeTable("decodes gemma4 array bodies into JSON arrays", + func(in, wantJSON string) { + Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON)) + }, + // vLLM: test_string_array / test_empty_array / test_bare_values + Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`), + Entry("empty array", "", `[]`), + Entry("bare values", "42,true,3.14", `[42,true,3.14]`), + // vLLM: test_string_element_with_closing_bracket (a ']' inside a + // delimited string must not close the array) + Entry("string element with closing bracket", 
`[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`), + // vLLM: test_stray_closing_bracket (no-progress abort, keep prefix) + Entry("stray closing bracket", "42,]trailing", `[42]`), + ) +}) diff --git a/backend/go/dllm/gemma4_renderer.go b/backend/go/dllm/gemma4_renderer.go new file mode 100755 index 000000000..868d98e4a --- /dev/null +++ b/backend/go/dllm/gemma4_renderer.go @@ -0,0 +1,1026 @@ +// Gemma4 (DiffusionGemma) chat template - NORMATIVE REFERENCE. +// +// The block comment below is the FULL `tokenizer.chat_template` extracted +// verbatim from diffusiongemma-26B-A4B-it-BF16.gguf via gguf-py's GGUFReader +// (17466 bytes, md5 8c34cf93c7a7815b3fdb300a009c4c17). Line numbers were +// added for citation only ("tpl L" throughout this file); the template +// text itself is untouched. RenderGemma4 replicates this template +// byte-for-byte (verified against jinja2 renders and the transformers +// fixtures in tests/models/diffusion_gemma/test_modeling_diffusion_gemma.py), +// with ONE deliberate exception: the leading `{{- bos_token -}}` is NOT +// emitted - see the BOS NOTE after the template. 
+// +/* + 1 {%- macro format_parameters(properties, required, filter_keys=false) -%} + 2 {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + 3 {%- set ns = namespace(found_first=false) -%} + 4 {%- for key, value in properties | dictsort -%} + 5 {%- set add_comma = false -%} + 6 {%- if not filter_keys or key not in standard_keys -%} + 7 {%- if ns.found_first %},{% endif -%} + 8 {%- set ns.found_first = true -%} + 9 {{ key }}:{ + 10 {%- if value['description'] -%} + 11 description:<|"|>{{ value['description'] }}<|"|> + 12 {%- set add_comma = true -%} + 13 {%- endif -%} + 14 {%- if value['type'] | upper == 'STRING' -%} + 15 {%- if value['enum'] -%} + 16 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 17 enum:{{ format_argument(value['enum']) }} + 18 {%- endif -%} + 19 {%- elif value['type'] | upper == 'ARRAY' -%} + 20 {%- if value['items'] is mapping and value['items'] -%} + 21 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 22 items:{ + 23 {%- set ns_items = namespace(found_first=false) -%} + 24 {%- for item_key, item_value in value['items'] | dictsort -%} + 25 {%- if item_value is not none -%} + 26 {%- if ns_items.found_first %},{% endif -%} + 27 {%- set ns_items.found_first = true -%} + 28 {%- if item_key == 'properties' -%} + 29 properties:{ + 30 {%- if item_value is mapping -%} + 31 {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + 32 {%- endif -%} + 33 } + 34 {%- elif item_key == 'required' -%} + 35 required:[ + 36 {%- for req_item in item_value -%} + 37 <|"|>{{- req_item -}}<|"|> + 38 {%- if not loop.last %},{% endif -%} + 39 {%- endfor -%} + 40 ] + 41 {%- elif item_key == 'type' -%} + 42 {%- if item_value is string -%} + 43 type:{{ format_argument(item_value | upper) }} + 44 {%- else -%} + 45 type:{{ format_argument(item_value | map('upper') | list) }} + 46 {%- endif -%} + 47 {%- else -%} + 48 {{ item_key }}:{{ 
format_argument(item_value) }} + 49 {%- endif -%} + 50 {%- endif -%} + 51 {%- endfor -%} + 52 } + 53 {%- endif -%} + 54 {%- endif -%} + 55 {%- if value['nullable'] %} + 56 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 57 nullable:true + 58 {%- endif -%} + 59 {%- if value['type'] | upper == 'OBJECT' -%} + 60 {%- if value['properties'] is defined and value['properties'] is mapping -%} + 61 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 62 properties:{ + 63 {{- format_parameters(value['properties'], value['required'] | default([])) -}} + 64 } + 65 {%- elif value is mapping -%} + 66 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 67 properties:{ + 68 {{- format_parameters(value, value['required'] | default([]), filter_keys=true) -}} + 69 } + 70 {%- endif -%} + 71 {%- if value['required'] -%} + 72 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 73 required:[ + 74 {%- for item in value['required'] | default([]) -%} + 75 <|"|>{{- item -}}<|"|> + 76 {%- if not loop.last %},{% endif -%} + 77 {%- endfor -%} + 78 ] + 79 {%- endif -%} + 80 {%- endif -%} + 81 {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + 82 type:<|"|>{{ value['type'] | upper }}<|"|>} + 83 {%- endif -%} + 84 {%- endfor -%} + 85 {%- endmacro -%} + 86 {%- macro format_function_declaration(tool_data) -%} + 87 declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + 88 {%- set params = tool_data['function']['parameters'] -%} + 89 {%- if params -%} + 90 ,parameters:{ + 91 {%- if params['properties'] -%} + 92 properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + 93 {%- endif -%} + 94 {%- if params['required'] -%} + 95 required:[ + 96 {%- for item in params['required'] -%} + 97 <|"|>{{- item -}}<|"|> + 98 {{- ',' if not loop.last -}} + 99 {%- endfor -%} + 100 ], + 101 {%- endif 
-%} + 102 {%- if params['type'] -%} + 103 type:<|"|>{{- params['type'] | upper -}}<|"|>} + 104 {%- endif -%} + 105 {%- endif -%} + 106 {%- if 'response' in tool_data['function'] -%} + 107 {%- set response_declaration = tool_data['function']['response'] -%} + 108 ,response:{ + 109 {%- if response_declaration['description'] -%} + 110 description:<|"|>{{- response_declaration['description'] -}}<|"|>, + 111 {%- endif -%} + 112 {%- if response_declaration['type'] | upper == 'OBJECT' -%} + 113 type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + 114 {%- endif -%} + 115 {%- endif -%} + 116 } + 117 {%- endmacro -%} + 118 {%- macro format_argument(argument, escape_keys=True) -%} + 119 {%- if argument is string -%} + 120 {{- '<|"|>' + argument + '<|"|>' -}} + 121 {%- elif argument is boolean -%} + 122 {{- 'true' if argument else 'false' -}} + 123 {%- elif argument is mapping -%} + 124 {{- '{' -}} + 125 {%- set ns = namespace(found_first=false) -%} + 126 {%- for key, value in argument | dictsort -%} + 127 {%- if ns.found_first %},{% endif -%} + 128 {%- set ns.found_first = true -%} + 129 {%- if escape_keys -%} + 130 {{- '<|"|>' + key + '<|"|>' -}} + 131 {%- else -%} + 132 {{- key -}} + 133 {%- endif -%} + 134 :{{- format_argument(value, escape_keys=escape_keys) -}} + 135 {%- endfor -%} + 136 {{- '}' -}} + 137 {%- elif argument is sequence -%} + 138 {{- '[' -}} + 139 {%- for item in argument -%} + 140 {{- format_argument(item, escape_keys=escape_keys) -}} + 141 {%- if not loop.last %},{% endif -%} + 142 {%- endfor -%} + 143 {{- ']' -}} + 144 {%- else -%} + 145 {{- argument -}} + 146 {%- endif -%} + 147 {%- endmacro -%} + 148 {%- macro strip_thinking(text) -%} + 149 {%- set ns = namespace(result='') -%} + 150 {%- for part in text.split('') -%} + 151 {%- if '<|channel>' in part -%} + 152 {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + 153 {%- else -%} + 154 {%- set ns.result = ns.result + part -%} + 155 {%- endif -%} + 156 {%- endfor -%} + 157 {{- 
ns.result | trim -}} + 158 {%- endmacro -%} + 159 + 160 {%- macro format_tool_response_block(tool_name, response) -%} + 161 {{- '<|tool_response>' -}} + 162 {%- if response is mapping -%} + 163 {{- 'response:' + tool_name + '{' -}} + 164 {%- for key, value in response | dictsort -%} + 165 {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + 166 {%- if not loop.last %},{% endif -%} + 167 {%- endfor -%} + 168 {{- '}' -}} + 169 {%- else -%} + 170 {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}} + 171 {%- endif -%} + 172 {{- '' -}} + 173 {%- endmacro -%} + 174 + 175 {%- set ns = namespace(prev_message_type=None) -%} + 176 {%- set loop_messages = messages -%} + 177 {{- bos_token -}} + 178 {#- Handle System/Tool Definitions Block -#} + 179 {%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + 180 {{- '<|turn>system\n' -}} + 181 {#- Inject Thinking token at the very top of the FIRST system turn -#} + 182 {%- if enable_thinking is defined and enable_thinking -%} + 183 {{- '<|think|>\n' -}} + 184 {%- set ns.prev_message_type = 'think' -%} + 185 {%- endif -%} + 186 {%- if messages[0]['role'] in ['system', 'developer'] -%} + 187 {%- if messages[0]['content'] is string -%} + 188 {{- messages[0]['content'] | trim -}} + 189 {%- elif messages[0]['content'] is sequence -%} + 190 {%- for item in messages[0]['content'] -%} + 191 {{- item['text'] | trim + ' '-}} + 192 {%- endfor -%} + 193 {%- endif -%} + 194 {%- set loop_messages = messages[1:] -%} + 195 {%- endif -%} + 196 {%- if tools -%} + 197 {%- for tool in tools %} + 198 {{- '<|tool>' -}} + 199 {{- format_function_declaration(tool) | trim -}} + 200 {{- '' -}} + 201 {%- endfor %} + 202 {%- set ns.prev_message_type = 'tool' -%} + 203 {%- endif -%} + 204 {{- '\n' -}} + 205 {%- endif %} + 206 + 207 {#- Pre-scan: find last user message index for reasoning guard -#} + 208 {%- set ns_turn = 
namespace(last_user_idx=-1) -%} + 209 {%- for i in range(loop_messages | length) -%} + 210 {%- if loop_messages[i]['role'] == 'user' -%} + 211 {%- set ns_turn.last_user_idx = i -%} + 212 {%- endif -%} + 213 {%- endfor -%} + 214 + 215 {#- Loop through messages -#} + 216 {%- for message in loop_messages -%} + 217 {%- if message['role'] != 'tool' -%} + 218 {%- set ns.prev_message_type = None -%} + 219 {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + 220 {#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#} + 221 {%- set prev_nt = namespace(role=None, found=false) -%} + 222 {%- if loop.index0 > 0 -%} + 223 {%- for j in range(loop.index0 - 1, -1, -1) -%} + 224 {%- if not prev_nt.found -%} + 225 {%- if loop_messages[j]['role'] != 'tool' -%} + 226 {%- set prev_nt.role = loop_messages[j]['role'] -%} + 227 {%- set prev_nt.found = true -%} + 228 {%- endif -%} + 229 {%- endif -%} + 230 {%- endfor -%} + 231 {%- endif -%} + 232 {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%} + 233 {%- if not continue_same_model_turn -%} + 234 {{- '<|turn>' + role + '\n' }} + 235 {%- endif -%} + 236 + 237 {#- Render reasoning/reasoning_content as thinking channel -#} + 238 {%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%} + 239 {%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%} + 240 {{- '<|channel>thought\n' + thinking_text + '\n' -}} + 241 {%- endif -%} + 242 + 243 {%- if message['tool_calls'] -%} + 244 {%- for tool_call in message['tool_calls'] -%} + 245 {%- set function = tool_call['function'] -%} + 246 {{- '<|tool_call>call:' + function['name'] + '{' -}} + 247 {%- if function['arguments'] is mapping -%} + 248 {%- set ns_args = namespace(found_first=false) -%} + 249 {%- for key, value in function['arguments'] | dictsort -%} + 250 {%- if ns_args.found_first %},{% endif -%} + 251 
{%- set ns_args.found_first = true -%} + 252 {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + 253 {%- endfor -%} + 254 {%- elif function['arguments'] is string -%} + 255 {{- function['arguments'] -}} + 256 {%- endif -%} + 257 {{- '}' -}} + 258 {%- endfor -%} + 259 {%- set ns.prev_message_type = 'tool_call' -%} + 260 {%- endif -%} + 261 + 262 {%- set ns_tr_out = namespace(flag=false) -%} + 263 {%- if message.get('tool_responses') -%} + 264 {#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#} + 265 {%- for tool_response in message['tool_responses'] -%} + 266 {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}} + 267 {%- set ns_tr_out.flag = true -%} + 268 {%- set ns.prev_message_type = 'tool_response' -%} + 269 {%- endfor -%} + 270 {%- elif message.get('tool_calls') -%} + 271 {#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#} + 272 {%- set ns_tool_scan = namespace(stopped=false) -%} + 273 {%- for k in range(loop.index0 + 1, loop_messages | length) -%} + 274 {%- if ns_tool_scan.stopped -%} + 275 {%- elif loop_messages[k]['role'] != 'tool' -%} + 276 {%- set ns_tool_scan.stopped = true -%} + 277 {%- else -%} + 278 {%- set follow = loop_messages[k] -%} + 279 {#- Resolve tool_call_id to function name -#} + 280 {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%} + 281 {%- for tc in message['tool_calls'] -%} + 282 {%- if tc.get('id') == follow.get('tool_call_id') -%} + 283 {%- set ns_tname.name = tc['function']['name'] -%} + 284 {%- endif -%} + 285 {%- endfor -%} + 286 {#- Handle content as string or content-parts array -#} + 287 {%- set tool_body = follow.get('content') -%} + 288 {%- if tool_body is string -%} + 289 {{- format_tool_response_block(ns_tname.name, tool_body) -}} + 290 {%- elif tool_body is sequence and tool_body is not string -%} + 291 {%- set ns_txt = namespace(s='') -%} + 292 {%- for part in tool_body -%} 
+ 293 {%- if part.get('type') == 'text' -%} + 294 {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%} + 295 {%- endif -%} + 296 {%- endfor -%} + 297 {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}} + 298 {%- for part in tool_body -%} + 299 {%- if part.get('type') == 'image' -%} + 300 {{- '<|image|>' -}} + 301 {%- elif part.get('type') == 'audio' -%} + 302 {{- '<|audio|>' -}} + 303 {%- elif part.get('type') == 'video' -%} + 304 {{- '<|video|>' -}} + 305 {%- endif -%} + 306 {%- endfor -%} + 307 {%- else -%} + 308 {{- format_tool_response_block(ns_tname.name, tool_body) -}} + 309 {%- endif -%} + 310 {%- set ns_tr_out.flag = true -%} + 311 {%- set ns.prev_message_type = 'tool_response' -%} + 312 {%- endif -%} + 313 {%- endfor -%} + 314 {%- endif -%} + 315 + 316 {%- set captured_content -%} + 317 {%- if message['content'] is string -%} + 318 {%- if role == 'model' -%} + 319 {{- strip_thinking(message['content']) -}} + 320 {%- else -%} + 321 {{- message['content'] | trim -}} + 322 {%- endif -%} + 323 {%- elif message['content'] is sequence -%} + 324 {%- for item in message['content'] -%} + 325 {%- if item['type'] == 'text' -%} + 326 {%- if role == 'model' -%} + 327 {{- strip_thinking(item['text']) -}} + 328 {%- else -%} + 329 {{- item['text'] | trim -}} + 330 {%- endif -%} + 331 {%- elif item['type'] == 'image' -%} + 332 {{- '<|image|>' -}} + 333 {%- set ns.prev_message_type = 'image' -%} + 334 {%- elif item['type'] == 'audio' -%} + 335 {{- '<|audio|>' -}} + 336 {%- set ns.prev_message_type = 'audio' -%} + 337 {%- elif item['type'] == 'video' -%} + 338 {{- '<|video|>' -}} + 339 {%- set ns.prev_message_type = 'video' -%} + 340 {%- endif -%} + 341 {%- endfor -%} + 342 {%- endif -%} + 343 {%- endset -%} + 344 + 345 {{- captured_content -}} + 346 {%- set has_content = captured_content | trim | length > 0 -%} + 347 + 348 {%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%} + 349 {{- '<|tool_response>' -}} + 350 {%- elif not 
(ns_tr_out.flag and not has_content) -%} + 351 {{- '\n' -}} + 352 {%- endif -%} + 353 {%- endif -%} + 354 {%- endfor -%} + 355 + 356 {%- if add_generation_prompt -%} + 357 {%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%} + 358 {{- '<|turn>model\n' -}} + 359 {%- if not enable_thinking | default(false) -%} + 360 {{- '<|channel>thought\n' -}} + 361 {%- endif -%} + 362 {%- endif -%} + 363 {%- endif -%}*/ + +// Every rule below cites "tpl L" line numbers from the numbered template +// text above. +// +// BOS NOTE (tpl L177 `{{- bos_token -}}`): the template emits because +// HF's apply_chat_template is expected to produce the FULL token stream. Our +// renderer feeds dllm_capi_generate, whose run_generate tokenizes with +// prepend_bos = vocab.add_bos (dllm.cpp src/capi.cpp:230-231), and gemma4 +// GGUFs carry add_bos=true - the C side prepends BOS itself. A literal +// "" here would therefore double it, so RenderGemma4 NEVER emits it. + +package main + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "sort" + "strings" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// Gemma4 marker vocabulary (special tokens referenced by the template). +const ( + gemma4StringDelim = `<|"|>` // string delimiter, tpl L119 etc. + gemma4TurnOpen = "<|turn>" // tpl L180/L234/L358 + // gemma4TurnEnd is the turn terminator as the MODEL emits it: the output + // parser (gemma4_parser.go) must trigger on the bare token, while the + // renderer appends the template's inter-turn newline (gemma4TurnClose). 
+ gemma4TurnEnd = "" // tpl L204/L351 + gemma4TurnClose = gemma4TurnEnd + "\n" // tpl L204/L351 + gemma4ThinkToken = "<|think|>\n" // tpl L183 + gemma4ToolOpen = "<|tool>" // tpl L198 + gemma4ToolClose = "" // tpl L200 + gemma4ToolCallOpen = "<|tool_call>" // tpl L246 + gemma4ToolCallClose = "" // tpl L257 + gemma4ToolResponseOpen = "<|tool_response>" // tpl L161/L349 + gemma4ToolResponseClose = "" // tpl L172 + gemma4ChannelOpen = "<|channel>" // tpl L240/L360 + gemma4ChannelClose = "" // tpl L240/L360 + gemma4ThoughtChannel = gemma4ChannelOpen + "thought\n" +) + +// gemma4ToolCall is the wire shape LocalAI core puts into pb.Message.ToolCalls +// (core/schema/message.go ToolCall marshalled by Messages.ToProto): a JSON +// array of {"index":..,"id":..,"type":..,"function":{"name":..,"arguments":..}}. +type gemma4ToolCall struct { + ID string `json:"id"` + Function struct { + Name string `json:"name"` + // Arguments is a JSON-encoded string in the OpenAI wire format + // (schema.FunctionCall.Arguments is a string), but kept raw here so a + // template-native object also works. See renderGemma4ToolCallArgs. + Arguments json.RawMessage `json:"arguments"` + } `json:"function"` +} + +// RenderGemma4 renders an OpenAI-style message list (plus the request's tools +// JSON array) into the gemma4 prompt string, replicating the GGUF chat +// template above byte-for-byte - except for the leading (see BOS NOTE). +// +// enableThinking maps to the template's enable_thinking flag (ds4 convention: +// Metadata["enable_thinking"]); addGenerationPrompt to add_generation_prompt. +func RenderGemma4(msgs []*pb.Message, toolsJSON string, enableThinking bool, addGenerationPrompt bool) (string, error) { + // Fail loud on roles the template does not know about. The jinja would + // happily render any role as a generic turn; we reject instead so typos + // surface at the API boundary rather than as silent bad prompts. 
+ for i, m := range msgs { + switch m.GetRole() { + case "system", "developer", "user", "assistant", "tool": + default: + return "", fmt.Errorf("dllm: gemma4 renderer: unknown role %q in message %d", m.GetRole(), i) + } + } + + tools, err := parseGemma4Tools(toolsJSON) + if err != nil { + return "", err + } + + var b strings.Builder + // ns.prev_message_type (tpl L175); "" stands for jinja None. + prev := "" + + // System/tool-definitions block (tpl L178-L205). + loopMsgs := msgs + firstIsSystem := len(msgs) > 0 && (msgs[0].GetRole() == "system" || msgs[0].GetRole() == "developer") + if enableThinking || len(tools) > 0 || firstIsSystem { + b.WriteString(gemma4TurnOpen + "system\n") // tpl L180 + if enableThinking { + // Thinking token at the very top of the first system turn, + // tpl L182-L185. NOTE: prev_message_type='think' (not used by + // the ending logic, mirrored for fidelity). + b.WriteString(gemma4ThinkToken) + prev = "think" + } + if firstIsSystem { + // First system/developer message is folded into this turn and + // consumed (loop_messages = messages[1:]), tpl L186-L195. + // pb.Message.Content is always a flattened string (core/schema/ + // message.go ToProto), so only the `is string` branch applies. + b.WriteString(strings.TrimSpace(msgs[0].GetContent())) + loopMsgs = msgs[1:] + } + if len(tools) > 0 { + // One <|tool>declaration:... block per tool, tpl L196-L203. + for _, t := range tools { + b.WriteString(gemma4ToolOpen) + b.WriteString(strings.TrimSpace(formatGemma4FunctionDeclaration(t))) + b.WriteString(gemma4ToolClose) + } + prev = "tool" + } + b.WriteString(gemma4TurnClose) // tpl L204 + } + + // Pre-scan: last user message index for the reasoning guard, tpl L207-L213. + lastUserIdx := -1 + for i, m := range loopMsgs { + if m.GetRole() == "user" { + lastUserIdx = i + } + } + + // Message loop, tpl L215-L354. role=tool messages are skipped here: they + // are rendered by the forward-scan from their assistant tool_calls turn. 
+ // consumedTool tracks which of them a forward-scan actually rendered, so + // an orphan tool message (no preceding assistant tool_calls turn) fails + // loud below instead of vanishing from the prompt. + consumedTool := make([]bool, len(loopMsgs)) + for i, m := range loopMsgs { + if m.GetRole() == "tool" { + continue + } + prev = "" // tpl L218 + role := m.GetRole() + if role == "assistant" { + role = "model" // tpl L219 + } + + // Continuation: suppress duplicate <|turn>model when the previous + // non-tool message was also assistant, tpl L220-L235. + prevNonToolRole := "" + for j := i - 1; j >= 0; j-- { + if loopMsgs[j].GetRole() != "tool" { + prevNonToolRole = loopMsgs[j].GetRole() + break + } + } + if !(role == "model" && prevNonToolRole == "assistant") { + b.WriteString(gemma4TurnOpen + role + "\n") + } + + var toolCalls []gemma4ToolCall + if tc := m.GetToolCalls(); strings.TrimSpace(tc) != "" { + if err := json.Unmarshal([]byte(tc), &toolCalls); err != nil { + return "", fmt.Errorf("dllm: gemma4 renderer: message %d: invalid tool_calls JSON: %w", i, err) + } + } + + // reasoning_content renders as a thought channel ONLY on the + // tool-calling turn after the last user message, tpl L237-L241. + if rc := m.GetReasoningContent(); rc != "" && i > lastUserIdx && len(toolCalls) > 0 { + b.WriteString(gemma4ThoughtChannel + rc + "\n" + gemma4ChannelClose) + } + + // Tool calls: <|tool_call>call:name{args}, concatenated + // without separators, tpl L243-L260. + if len(toolCalls) > 0 { + for _, tc := range toolCalls { + b.WriteString(gemma4ToolCallOpen + "call:" + tc.Function.Name + "{") + b.WriteString(renderGemma4ToolCallArgs(tc.Function.Arguments)) + b.WriteString("}" + gemma4ToolCallClose) + } + prev = "tool_call" + } + + // Tool responses: pb has no legacy tool_responses field (tpl + // L263-L269 is unreachable through proto), so only the OpenAI + // forward-scan of consecutive role=tool messages applies, + // tpl L270-L313. 
+ trOut := false + if len(toolCalls) > 0 { + for k := i + 1; k < len(loopMsgs); k++ { + if loopMsgs[k].GetRole() != "tool" { + break + } + follow := loopMsgs[k] + // Resolve tool_call_id to the function name; the message's + // own name (default 'unknown') is the fallback, tpl L278-L285. + name := follow.GetName() + if name == "" { + name = "unknown" + } + for _, tc := range toolCalls { + if tc.ID == follow.GetToolCallId() { + name = tc.Function.Name + } + } + // pb content is a flattened string: only the string body + // branch (tpl L287-L289) is reachable. + b.WriteString(formatGemma4ToolResponseBlock(name, follow.GetContent())) + consumedTool[k] = true + trOut = true + prev = "tool_response" + } + } + + // Captured content, tpl L316-L345. Model content gets thinking + // channels stripped (strip_thinking, tpl L148-L158); other roles are + // trimmed. pb content is a flattened string: the content-parts array + // branch (tpl L322-L342, incl. <|image|> markers) is unreachable. + var content string + if role == "model" { + content = stripGemma4Thinking(m.GetContent()) + } else { + content = strings.TrimSpace(m.GetContent()) + } + b.WriteString(content) + hasContent := strings.TrimSpace(content) != "" // tpl L346 + + // Turn ending, tpl L348-L353: a tool_calls turn with no rendered + // responses ends on an OPEN <|tool_response> (the runtime fills it); + // a turn whose only payload was tool responses stays open (no + // ); everything else closes the turn. + if prev == "tool_call" && !trOut { + b.WriteString(gemma4ToolResponseOpen) + } else if !(trOut && !hasContent) { + b.WriteString(gemma4TurnClose) + } + } + + // Fail loud on orphan role:tool messages no forward-scan consumed (e.g. a + // tool message with no preceding assistant tool_calls turn): the jinja + // would silently drop them from the prompt; we surface the bad request + // instead, same philosophy as the unknown-role check above. 
+ for i, m := range loopMsgs { + if m.GetRole() == "tool" && !consumedTool[i] { + return "", fmt.Errorf("dllm: gemma4 renderer: orphan tool message %d: no preceding assistant tool_calls turn consumed it", i+(len(msgs)-len(loopMsgs))) + } + } + + // Generation prompt, tpl L356-L362: never reopened right after a + // tool_call/tool_response (the model continues its own open turn); the + // thought channel is pre-opened only when thinking is NOT enabled. + if addGenerationPrompt && prev != "tool_response" && prev != "tool_call" { + b.WriteString(gemma4TurnOpen + "model\n") + if !enableThinking { + b.WriteString(gemma4ThoughtChannel + gemma4ChannelClose) + } + } + return b.String(), nil +} + +// parseGemma4Tools decodes the request's OpenAI tools JSON array +// ([{"type":"function","function":{...}}]). Numbers are kept as json.Number +// so 42 / 3.5 / 1.0 render exactly as jinja renders the Python values. +// An empty/null/[] input is jinja-falsy (tpl L196 `{%- if tools -%}`). +func parseGemma4Tools(toolsJSON string) ([]map[string]any, error) { + s := strings.TrimSpace(toolsJSON) + if s == "" || s == "null" { + return nil, nil + } + v, err := decodeGemma4JSON([]byte(s)) + if err != nil { + return nil, fmt.Errorf("dllm: gemma4 renderer: invalid tools JSON: %w", err) + } + list, ok := v.([]any) + if !ok { + return nil, fmt.Errorf("dllm: gemma4 renderer: tools JSON is not an array") + } + tools := make([]map[string]any, 0, len(list)) + for i, e := range list { + m, ok := e.(map[string]any) + if !ok { + return nil, fmt.Errorf("dllm: gemma4 renderer: tools[%d] is not an object", i) + } + tools = append(tools, m) + } + return tools, nil +} + +// decodeGemma4JSON unmarshals with UseNumber so numeric literals survive +// verbatim ("1.0" stays "1.0", matching jinja's rendering of Python 1.0). 
+// Trailing non-whitespace after the first value is rejected: json.Decoder +// stops at the value boundary, and silently ignoring the rest would render +// a prompt from a prefix of what the caller sent. +func decodeGemma4JSON(data []byte) (any, error) { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.UseNumber() + var v any + if err := dec.Decode(&v); err != nil { + return nil, err + } + if err := dec.Decode(new(any)); !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("trailing data after JSON value") + } + return v, nil +} + +// renderGemma4ToolCallArgs renders the arguments between the braces of +// call:name{...}, tpl L247-L256: a mapping renders as dictsorted +// key:format_argument(value, escape_keys=False) pairs; a string renders +// verbatim; anything else renders nothing (mirroring the if/elif). +// +// DIVERGENCE NOTE: through pb the arguments arrive as a JSON-encoded string +// (OpenAI wire format; schema.FunctionCall.Arguments is a string). HF/vLLM +// parse that string into a dict before applying the template, so we do the +// same: a string that parses as a JSON object takes the mapping branch; only +// a non-object string falls back to the template's verbatim string branch. +// +// Also note: string values containing the literal <|"|> delimiter render +// unescaped (prompt-structure injection), byte-faithful to the jinja which +// has identical behavior. +func renderGemma4ToolCallArgs(raw json.RawMessage) string { + if len(bytes.TrimSpace(raw)) == 0 { + return "" + } + v, err := decodeGemma4JSON(raw) + if err != nil { + // Not JSON at all: treat like the template's string branch on the + // raw bytes (never drop caller data silently). 
+ return string(raw) + } + if s, ok := v.(string); ok { + inner, err := decodeGemma4JSON([]byte(s)) + if err == nil { + if m, ok := inner.(map[string]any); ok { + v = m + } else { + return s // tpl L253-L254: string renders verbatim + } + } else { + return s + } + } + m, ok := v.(map[string]any) + if !ok { + return "" // tpl L247-L255: non-mapping, non-string renders nothing + } + parts := make([]string, 0, len(m)) + for _, k := range gemma4DictsortKeys(m) { + parts = append(parts, k+":"+formatGemma4Argument(m[k], false)) + } + return strings.Join(parts, ",") +} + +// formatGemma4Argument is the format_argument macro, tpl L118-L147: +// strings get <|"|> delimiters, booleans lower-case, mappings dictsorted +// {key:value} (keys delimited only when escape_keys), sequences [..], +// everything else verbatim (json.Number keeps its literal; null renders +// "None" exactly as jinja renders Python None). +func formatGemma4Argument(v any, escapeKeys bool) string { + switch a := v.(type) { + case string: + return gemma4StringDelim + a + gemma4StringDelim + case bool: + if a { + return "true" + } + return "false" + case map[string]any: + var b strings.Builder + b.WriteString("{") + for i, k := range gemma4DictsortKeys(a) { + if i > 0 { + b.WriteString(",") + } + if escapeKeys { + b.WriteString(gemma4StringDelim + k + gemma4StringDelim) + } else { + b.WriteString(k) + } + b.WriteString(":" + formatGemma4Argument(a[k], escapeKeys)) + } + b.WriteString("}") + return b.String() + case []any: + var b strings.Builder + b.WriteString("[") + for i, item := range a { + if i > 0 { + b.WriteString(",") + } + b.WriteString(formatGemma4Argument(item, escapeKeys)) + } + b.WriteString("]") + return b.String() + case json.Number: + return a.String() + case nil: + return "None" // jinja renders Python None as "None" + default: + return fmt.Sprint(a) + } +} + +// gemma4DictsortKeys mirrors jinja's dictsort default: case-insensitive sort +// by key. 
// gemma4DictsortKeys returns the keys of m ordered the way jinja's default
// dictsort orders them: by lower-cased key. Keys that are equal once
// lower-cased are ordered by their raw spelling so the result is
// deterministic (Go maps carry no insertion order to fall back on).
func gemma4DictsortKeys(m map[string]any) []string {
	sorted := make([]string, 0, len(m))
	for name := range m {
		sorted = append(sorted, name)
	}
	sort.Slice(sorted, func(a, b int) bool {
		left, right := sorted[a], sorted[b]
		if foldedLeft, foldedRight := strings.ToLower(left), strings.ToLower(right); foldedLeft != foldedRight {
			return foldedLeft < foldedRight
		}
		return left < right
	})
	return sorted
}

// gemma4Lookup models jinja's value['key'] on a value of unknown type: it
// returns the entry when v is a JSON object holding key, and nil (standing in
// for jinja Undefined) when v is not an object or the key is absent.
func gemma4Lookup(v any, key string) any {
	mapping, isMapping := v.(map[string]any)
	if !isMapping {
		return nil
	}
	return mapping[key]
}

// gemma4Truthy reports jinja truthiness for decoded JSON values: nil, false,
// "", numeric zero, and empty objects/arrays are falsy. A json.Number that
// cannot be converted to float64 counts as truthy, as does any other type.
func gemma4Truthy(v any) bool {
	switch val := v.(type) {
	case nil:
		return false
	case bool:
		return val
	case string:
		return len(val) != 0
	case json.Number:
		number, err := val.Float64()
		if err != nil {
			return true
		}
		return number != 0
	case map[string]any:
		return len(val) != 0
	case []any:
		return len(val) != 0
	}
	return true
}

// gemma4Str renders a scalar the way `{{ value }}` would: nil (Undefined)
// becomes "", strings and json.Number keep their text, booleans use the
// Python spelling "True"/"False", and anything else goes through fmt.Sprint.
func gemma4Str(v any) string {
	switch val := v.(type) {
	case nil:
		return ""
	case string:
		return val
	case json.Number:
		return val.String()
	case bool:
		// Python bool repr; only reachable via odd schemas.
		if !val {
			return "False"
		}
		return "True"
	}
	return fmt.Sprint(v)
}

// gemma4TypeUpper is `value['type'] | upper`: the upper-cased rendering of
// v's "type" entry, or "" when v is not an object or has no such entry.
func gemma4TypeUpper(v any) string {
	declared := gemma4Lookup(v, "type")
	return strings.ToUpper(gemma4Str(declared))
}
+func gemma4QuoteJoin(list []any) string { + parts := make([]string, 0, len(list)) + for _, item := range list { + parts = append(parts, gemma4StringDelim+gemma4Str(item)+gemma4StringDelim) + } + return strings.Join(parts, ",") +} + +// formatGemma4FunctionDeclaration is the format_function_declaration macro, +// tpl L86-L117: declaration:name{description:<|"|>..<|"|>[,parameters:{..}] +// [,response:{..}]}. Brace placement (incl. the parameters block being closed +// by the type clause) is replicated exactly, quirks and all. +func formatGemma4FunctionDeclaration(tool map[string]any) string { + fn, _ := tool["function"].(map[string]any) + var b strings.Builder + b.WriteString("declaration:" + gemma4Str(gemma4Lookup(fn, "name"))) + b.WriteString("{description:" + gemma4StringDelim + gemma4Str(gemma4Lookup(fn, "description")) + gemma4StringDelim) + params := gemma4Lookup(fn, "parameters") + if gemma4Truthy(params) { // tpl L89 + b.WriteString(",parameters:{") + if props, ok := gemma4Lookup(params, "properties").(map[string]any); ok && gemma4Truthy(gemma4Lookup(params, "properties")) { // tpl L92 + required, _ := gemma4Lookup(params, "required").([]any) + b.WriteString("properties:{" + formatGemma4Parameters(props, required, false) + "},") + } + if required, ok := gemma4Lookup(params, "required").([]any); ok && len(required) > 0 { // tpl L95 + b.WriteString("required:[" + gemma4QuoteJoin(required) + "],") + } + if gemma4Truthy(gemma4Lookup(params, "type")) { // tpl L102: closes the parameters block + b.WriteString("type:" + gemma4StringDelim + gemma4TypeUpper(params) + gemma4StringDelim + "}") + } + } + if fn != nil { // tpl L106: `'response' in tool_data['function']` + if resp, present := fn["response"]; present { + b.WriteString(",response:{") + if gemma4Truthy(gemma4Lookup(resp, "description")) { + b.WriteString("description:" + gemma4StringDelim + gemma4Str(gemma4Lookup(resp, "description")) + gemma4StringDelim + ",") + } + if gemma4TypeUpper(resp) == "OBJECT" { 
// tpl L112: closes the response block + b.WriteString("type:" + gemma4StringDelim + gemma4TypeUpper(resp) + gemma4StringDelim + "}") + } + } + } + b.WriteString("}") + return b.String() +} + +// formatGemma4Parameters is the format_parameters macro, tpl L1-L85. Each +// property renders as key:{[description][,enum|items][,nullable][,properties] +// [,required],type:<|"|>TYPE<|"|>} with the comma threading of the macro's +// add_comma flag. +func formatGemma4Parameters(properties map[string]any, required []any, filterKeys bool) string { + _ = required // tpl L1: passed through by callers but never read here + standardKeys := map[string]bool{ // tpl L2 + "description": true, "type": true, "properties": true, "required": true, "nullable": true, + } + var b strings.Builder + foundFirst := false + for _, key := range gemma4DictsortKeys(properties) { + if filterKeys && standardKeys[key] { // tpl L6 + continue + } + value := properties[key] + if foundFirst { + b.WriteString(",") + } + foundFirst = true + b.WriteString(key + ":{") // tpl L9 + addComma := false + comma := func() { + if addComma { + b.WriteString(",") + } else { + addComma = true + } + } + typeUpper := gemma4TypeUpper(value) + + if gemma4Truthy(gemma4Lookup(value, "description")) { // tpl L10-L13 + b.WriteString("description:" + gemma4StringDelim + gemma4Str(gemma4Lookup(value, "description")) + gemma4StringDelim) + addComma = true + } + switch typeUpper { + case "STRING": // tpl L14-L19 + if enum := gemma4Lookup(value, "enum"); gemma4Truthy(enum) { + comma() + b.WriteString("enum:" + formatGemma4Argument(enum, true)) + } + case "ARRAY": // tpl L20-L55 + if items, ok := gemma4Lookup(value, "items").(map[string]any); ok && len(items) > 0 { + comma() + b.WriteString("items:{") + itemsFound := false + for _, itemKey := range gemma4DictsortKeys(items) { + itemValue := items[itemKey] + if itemValue == nil { // tpl L25: `is not none` + continue + } + if itemsFound { + b.WriteString(",") + } + itemsFound = true + 
switch itemKey { + case "properties": // tpl L29-L34 + b.WriteString("properties:{") + if m, ok := itemValue.(map[string]any); ok { + itemsRequired, _ := items["required"].([]any) + b.WriteString(formatGemma4Parameters(m, itemsRequired, false)) + } + b.WriteString("}") + case "required": // tpl L35-L41 + list, _ := itemValue.([]any) + b.WriteString("required:[" + gemma4QuoteJoin(list) + "]") + case "type": // tpl L42-L47 + if s, ok := itemValue.(string); ok { + b.WriteString("type:" + formatGemma4Argument(strings.ToUpper(s), true)) + } else if list, ok := itemValue.([]any); ok { + upped := make([]any, len(list)) + for li, lv := range list { + upped[li] = strings.ToUpper(gemma4Str(lv)) + } + b.WriteString("type:" + formatGemma4Argument(upped, true)) + } + default: // tpl L48-L49 + b.WriteString(itemKey + ":" + formatGemma4Argument(itemValue, true)) + } + } + b.WriteString("}") + } + } + if gemma4Truthy(gemma4Lookup(value, "nullable")) { // tpl L56-L59 + comma() + b.WriteString("nullable:true") + } + if typeUpper == "OBJECT" { // tpl L60-L80 + if props, ok := gemma4Lookup(value, "properties").(map[string]any); ok { // tpl L61: defined and mapping + comma() + req, _ := gemma4Lookup(value, "required").([]any) + b.WriteString("properties:{" + formatGemma4Parameters(props, req, false) + "}") + } else if vm, ok := value.(map[string]any); ok { // tpl L66 + comma() + req, _ := gemma4Lookup(value, "required").([]any) + b.WriteString("properties:{" + formatGemma4Parameters(vm, req, true) + "}") + } + if req, ok := gemma4Lookup(value, "required").([]any); ok && len(req) > 0 { // tpl L72 + comma() + b.WriteString("required:[" + gemma4QuoteJoin(req) + "]") + } + } + comma() // tpl L81-L82: type is always last and closes the property + b.WriteString("type:" + gemma4StringDelim + typeUpper + gemma4StringDelim + "}") + } + return b.String() +} + +// formatGemma4ToolResponseBlock is the format_tool_response_block macro, +// tpl L160-L173, restricted to the string-response branch: pb 
tool messages +// carry flattened string content, so the mapping branch is unreachable. +func formatGemma4ToolResponseBlock(toolName, response string) string { + return gemma4ToolResponseOpen + + "response:" + toolName + "{value:" + formatGemma4Argument(response, false) + "}" + + gemma4ToolResponseClose +} + +// stripGemma4Thinking is the strip_thinking macro, tpl L148-L158: split on +// , drop everything from <|channel> onward in each part, trim. +func stripGemma4Thinking(text string) string { + var b strings.Builder + for _, part := range strings.Split(text, gemma4ChannelClose) { + if idx := strings.Index(part, gemma4ChannelOpen); idx >= 0 { + b.WriteString(part[:idx]) + } else { + b.WriteString(part) + } + } + return strings.TrimSpace(b.String()) +} diff --git a/backend/go/dllm/gemma4_renderer_test.go b/backend/go/dllm/gemma4_renderer_test.go new file mode 100755 index 000000000..3600fbf7a --- /dev/null +++ b/backend/go/dllm/gemma4_renderer_test.go @@ -0,0 +1,347 @@ +package main + +// Renderer specs for RenderGemma4 against the canonical gemma4 chat template +// (see the normative template comment in gemma4_renderer.go). +// +// Fixture provenance: +// - "single user message" and "enable_thinking" are the EXACT expected +// decodes from transformers tests/models/diffusion_gemma/ +// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template +// and ..._with_thinking) with ONE difference: the transformers fixtures +// start with "" because apply_chat_template tokenizes the rendered +// text with add_bos. Our prompt goes through dllm_capi_generate, whose +// run_generate already tokenizes with prepend_bos = vocab.add_bos +// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must +// NOT emit a literal (it would double) and every expected string +// here drops that leading token. 
+// - All other expected strings were produced by rendering the verbatim +// GGUF template with jinja2 3.1.2 (bos_token="") and dropping the +// leading "" for the same reason. + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// Two-function tools array used by the tool fixtures (OpenAI wire shape, as +// LocalAI passes it through PredictOptions.Tools). +const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]` + +// The <|tool>... block the template renders for testToolsJSON inside +// the system turn (jinja2-verified). +const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}` + +// A single tool exercising the deep format_parameters branches: array items +// (string-typed and nested-array), nullable, enum+nullable, nested object +// properties/required, and a response declaration. 
+const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]` + +// jinja2-verified render of complexToolsJSON. Notable template quirks pinned +// here: nested array items go through format_argument with ESCAPED keys and +// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item +// types are uppercased; properties dictsort case-insensitively. +const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}` + +type renderGemma4Case struct { + msgs []*pb.Message + toolsJSON string + enableThinking bool + noGenerationPrompt bool // inverted so the zero value is the common case + expected string +} + +var _ = Describe("RenderGemma4", func() { + DescribeTable("renders the canonical gemma4 prompt", + func(c renderGemma4Case) { + out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt) + Expect(err).ToNot(HaveOccurred()) + 
Expect(out).To(Equal(c.expected)) + // The C-ABI generate prepends BOS itself: a literal + // anywhere in the rendered prompt would double-encode it. + Expect(out).ToNot(ContainSubstring("")) + }, + + // transformers fixture (test_diffusion_gemma_chat_template), sans : + // default thinking pre-opens an EMPTY thought channel in the + // generation prompt. + Entry("single user message, default (no thinking)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Write a long essay about Portugal."}, + }, + expected: "<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n<|channel>thought\n", + }), + + // transformers fixture (test_diffusion_gemma_chat_template_with_thinking), + // sans : a system turn carrying <|think|> and NO auto-opened + // thought channel. + Entry("enable_thinking=true", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Write a long essay about Portugal."}, + }, + enableThinking: true, + expected: "<|turn>system\n<|think|>\n\n<|turn>user\nWrite a long essay about Portugal.\n<|turn>model\n", + }), + + Entry("multi-turn user/assistant/user", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "Hello, who are you?"}, + {Role: "assistant", Content: "I am Gemma, a helpful assistant."}, + {Role: "user", Content: "Tell me a joke."}, + }, + expected: "<|turn>user\nHello, who are you?\n<|turn>model\nI am Gemma, a helpful assistant.\n<|turn>user\nTell me a joke.\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L178-L195: a leading system message is folded into the system + // turn (trimmed) and consumed from the loop. 
+ Entry("system message folds into the system turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "system", Content: "You are a pirate."}, + {Role: "user", Content: "Hello!"}, + }, + expected: "<|turn>system\nYou are a pirate.\n<|turn>user\nHello!\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L182-L185: <|think|> goes at the very top of the SAME system + // turn, before the system prompt text. + Entry("system message with enable_thinking shares the turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "system", Content: "You are a pirate."}, + {Role: "user", Content: "Hello!"}, + }, + enableThinking: true, + expected: "<|turn>system\n<|think|>\nYou are a pirate.\n<|turn>user\nHello!\n<|turn>model\n", + }), + + // tpl L196-L203: tool declarations render in the system turn, one + // <|tool>declaration:... block per tool, no separators. + Entry("tools array (two functions)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is the weather in Tokyo?"}, + }, + toolsJSON: testToolsJSON, + expected: "<|turn>system\n" + testToolsBlock + "\n<|turn>user\nWhat is the weather in Tokyo?\n<|turn>model\n<|channel>thought\n", + }), + + // format_parameters deep branches (tpl L1-L85) + response declaration + // (tpl L106-L116). + Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + }, + toolsJSON: complexToolsJSON, + expected: "<|turn>system\n" + complexToolsBlock + "\n<|turn>user\ngo\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L243-L313: assistant tool_calls render as + // <|tool_call>call:name{args}; the following role=tool + // message renders inline as <|tool_response>response:name{value:..} + // ; the model turn stays OPEN (no , no new + // generation prompt) so the model continues after the response. 
+ Entry("assistant tool_calls + role=tool result", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is the weather in Tokyo?"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`}, + {Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."}, + }, + toolsJSON: testToolsJSON, + expected: "<|turn>system\n" + testToolsBlock + "\n<|turn>user\nWhat is the weather in Tokyo?\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}`, + }), + + // tpl L348-L349: a tool_calls turn with no rendered responses ends + // on an OPEN <|tool_response> marker for the runtime to fill, and + // add_generation_prompt adds nothing (tpl L357). + Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "What is the weather in Tokyo?"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`}, + }, + toolsJSON: testToolsJSON, + expected: "<|turn>system\n" + testToolsBlock + "\n<|turn>user\nWhat is the weather in Tokyo?\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<|tool_response>`, + }), + + // tpl L237-L241: reasoning_content renders as a thought channel only + // on a tool-calling turn after the last user message. 
+ Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "weather?"}, + {Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`}, + {Role: "tool", ToolCallId: "c1", Content: "Sunny"}, + }, + expected: "<|turn>user\nweather?\n<|turn>model\n<|channel>thought\nI should call the tool\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}`, + }), + + // tpl L220-L235: the assistant answer following its own tool round + // continues the SAME model turn (no second <|turn>model). + Entry("tool round then final assistant answer then user", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "weather?"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`}, + {Role: "tool", ToolCallId: "c1", Content: "Sunny"}, + {Role: "assistant", Content: "It is sunny."}, + {Role: "user", Content: "thanks"}, + }, + expected: "<|turn>user\nweather?\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}` + "It is sunny.\n<|turn>user\nthanks\n<|turn>model\n<|channel>thought\n", + }), + + // format_argument (tpl L118-L147): numbers keep their JSON literal, + // booleans lower-case, nested maps have unquoted dictsorted keys, + // arrays bracketed; top-level args are dictsorted case-insensitively. 
+ Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<|tool_response>`, + }), + + // jinja dictsort is case-insensitive: alpha sorts before Beta. + Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<|tool_response>", + }), + + // jinja renders Python None as "None" (round-trips through vLLM's + // parser, which lowers "none" back to null). + Entry("tool_call null argument renders as None", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{maybe:None}<|tool_response>", + }), + + Entry("tool_call empty arguments render empty braces", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{}<|tool_response>", + }), + + // tpl L253-L254: a non-object arguments string renders verbatim. 
+ Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n<|tool_call>call:f{just text}<|tool_response>", + }), + + // tpl L278-L285: unmatched tool_call_id falls back to the tool + // message's own name. + Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "go"}, + {Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`}, + {Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"}, + }, + expected: "<|turn>user\ngo\n<|turn>model\n" + `<|tool_call>call:f{}<|tool_response>response:named_tool{value:<|"|>out<|"|>}`, + }), + + // strip_thinking (tpl L148-L158): historical assistant content loses + // its <|channel>... spans. + Entry("assistant content thinking channels are stripped", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "<|channel>thought\nsecret\nvisible answer"}, + {Role: "user", Content: "more"}, + }, + expected: "<|turn>user\nhi\n<|turn>model\nvisible answer\n<|turn>user\nmore\n<|turn>model\n<|channel>thought\n", + }), + + // tpl L220-L235: consecutive assistant messages suppress the second + // <|turn>model (continuation), but each still closes with . 
+ Entry("consecutive assistant messages continue the model turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "part one"}, + {Role: "assistant", Content: "part two"}, + {Role: "user", Content: "ok"}, + }, + expected: "<|turn>user\nhi\n<|turn>model\npart one\npart two\n<|turn>user\nok\n<|turn>model\n<|channel>thought\n", + }), + + Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{ + msgs: []*pb.Message{ + {Role: "user", Content: "hi"}, + }, + noGenerationPrompt: true, + expected: "<|turn>user\nhi\n", + }), + ) + + Describe("error handling", func() { + It("fails loud on an unknown role", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "narrator", Content: "Meanwhile..."}, + }, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`)) + }) + + It("fails on invalid tools JSON", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, "{not json", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools JSON")) + }) + + It("fails on invalid tool_calls JSON", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "", ToolCalls: "{not json"}, + }, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tool_calls JSON")) + }) + + It("fails on an orphan tool message, naming its index", func() { + // A role:tool message with no preceding assistant tool_calls turn + // would be silently dropped by the jinja; we fail loud instead. 
+ _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + {Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"}, + }, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("orphan tool message 1")) + }) + + It("fails on trailing garbage after the tools JSON array", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, "[] junk", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools JSON")) + }) + + It("fails when the tools JSON is not an array", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, `{"type":"function"}`, false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools JSON is not an array")) + }) + + It("fails when a tools array element is not an object", func() { + _, err := RenderGemma4([]*pb.Message{ + {Role: "user", Content: "hi"}, + }, `[42]`, false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tools[0] is not an object")) + }) + + It("rejects a nil message via the unknown-role check", func() { + // Pins current behavior: pb getters are nil-safe, so a nil message + // reads as role "" and trips the fail-loud unknown-role guard. + _, err := RenderGemma4([]*pb.Message{nil}, "", false, true) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`)) + }) + }) +})