mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
9 Commits
parth/decr
...
parth/gpt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1fe7e07f63 | ||
|
|
40d3436cd1 | ||
|
|
5bc783b58e | ||
|
|
87714c1c39 | ||
|
|
f7ca3b7f7e | ||
|
|
72189c6d6e | ||
|
|
1d09e01431 | ||
|
|
eb7660d724 | ||
|
|
4a5bdd5f12 |
@@ -1,17 +1,32 @@
|
|||||||
package harmony
|
package harmony
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
type harmonyParserState int
|
type harmonyParserState int
|
||||||
|
|
||||||
|
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
|
||||||
|
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
|
||||||
|
// heuristic to check whether the template expects to be parsed via harmony:
|
||||||
|
// search for harmony tags that are nearly always used
|
||||||
|
if template.Contains("<|start|>") && template.Contains("<|end|>") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||||
harmonyParserState_ParsingHeader
|
harmonyParserState_ParsingHeader
|
||||||
@@ -33,12 +48,13 @@ func (s harmonyParserState) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type HarmonyParser struct {
|
type HarmonyParser struct {
|
||||||
state harmonyParserState
|
state harmonyParserState
|
||||||
MessageStartTag string
|
MessageStartTag string
|
||||||
MessageEndTag string
|
MessageEndTag string
|
||||||
HeaderEndTag string
|
HeaderEndTag string
|
||||||
acc strings.Builder
|
ConstrainAllowed bool
|
||||||
lifetimeAcc strings.Builder
|
acc strings.Builder
|
||||||
|
lifetimeAcc strings.Builder
|
||||||
}
|
}
|
||||||
|
|
||||||
type HarmonyEvent interface {
|
type HarmonyEvent interface {
|
||||||
@@ -75,17 +91,18 @@ func (s *HarmonyParser) AddImplicitStart() {
|
|||||||
s.acc.WriteString("<|start|>assistant")
|
s.acc.WriteString("<|start|>assistant")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
|
||||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
|
||||||
// handle prefilling conditions
|
if prefillContentOrThinking != nil {
|
||||||
if lastMessage.Content != "" {
|
if *prefillContentOrThinking {
|
||||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||||
return
|
return
|
||||||
} else if lastMessage.Thinking != "" {
|
} else {
|
||||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.AddImplicitStart()
|
s.AddImplicitStart()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,6 +329,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
|||||||
}
|
}
|
||||||
case "final":
|
case "final":
|
||||||
h.state = harmonyMessageState_Normal
|
h.state = harmonyMessageState_Normal
|
||||||
|
h.HarmonyParser.ConstrainAllowed = true
|
||||||
}
|
}
|
||||||
case HarmonyEventContentEmitted:
|
case HarmonyEventContentEmitted:
|
||||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||||
@@ -377,6 +395,38 @@ type FunctionNameMap struct {
|
|||||||
harmonyToUser map[string]string
|
harmonyToUser map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
|
||||||
|
// necessary to avoid exposing map internals
|
||||||
|
type alias struct {
|
||||||
|
UserToHarmony map[string]string `json:"userToHarmony"`
|
||||||
|
HarmonyToUser map[string]string `json:"harmonyToUser"`
|
||||||
|
}
|
||||||
|
return json.Marshal(alias{
|
||||||
|
UserToHarmony: m.userToHarmony,
|
||||||
|
HarmonyToUser: m.harmonyToUser,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
|
||||||
|
type alias struct {
|
||||||
|
UserToHarmony map[string]string `json:"userToHarmony"`
|
||||||
|
HarmonyToUser map[string]string `json:"harmonyToUser"`
|
||||||
|
}
|
||||||
|
var a alias
|
||||||
|
if err := json.Unmarshal(b, &a); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if m.userToHarmony == nil {
|
||||||
|
m.userToHarmony = make(map[string]string)
|
||||||
|
}
|
||||||
|
if m.harmonyToUser == nil {
|
||||||
|
m.harmonyToUser = make(map[string]string)
|
||||||
|
}
|
||||||
|
maps.Copy(m.userToHarmony, a.UserToHarmony)
|
||||||
|
maps.Copy(m.harmonyToUser, a.HarmonyToUser)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewFunctionNameMap() *FunctionNameMap {
|
func NewFunctionNameMap() *FunctionNameMap {
|
||||||
return &FunctionNameMap{
|
return &FunctionNameMap{
|
||||||
userToHarmony: make(map[string]string),
|
userToHarmony: make(map[string]string),
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package harmony
|
package harmony
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -535,3 +537,224 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
|
||||||
|
t.Run("thinking_then_content_streams", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
type step struct {
|
||||||
|
in string
|
||||||
|
wantContent string
|
||||||
|
wantThinking string
|
||||||
|
}
|
||||||
|
steps := []step{
|
||||||
|
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
|
||||||
|
{in: "<|end|>", wantThinking: ""},
|
||||||
|
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
|
||||||
|
{in: "<|end|>", wantContent: ""},
|
||||||
|
}
|
||||||
|
for i, s := range steps {
|
||||||
|
content, thinking, tool := handler.AddContent(s.in, tp)
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
if content != s.wantContent || thinking != s.wantThinking {
|
||||||
|
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
inputs := []string{
|
||||||
|
"<|start|>assistant<|message|>Hello",
|
||||||
|
", world",
|
||||||
|
"!<|end|>",
|
||||||
|
}
|
||||||
|
var got []string
|
||||||
|
for _, in := range inputs {
|
||||||
|
content, thinking, tool := handler.AddContent(in, tp)
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("unexpected thinking %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
got = append(got, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
want := []string{"Hello", ", world", "!"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
inputs := []string{
|
||||||
|
"<|channel|>analysis<|message|>Thinking...",
|
||||||
|
"<|end|>",
|
||||||
|
"<|start|>assistant<|message|>Answer",
|
||||||
|
"<|end|>",
|
||||||
|
}
|
||||||
|
var got []string
|
||||||
|
for _, in := range inputs {
|
||||||
|
content, thinking, tool := handler.AddContent(in, tp)
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
got = append(got, thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
got = append(got, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
want := []string{"Thinking...", "Answer"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
inputs := []string{
|
||||||
|
"<|chan",
|
||||||
|
"nel|>analysis<|mess",
|
||||||
|
"age|>Deep ",
|
||||||
|
"thought",
|
||||||
|
"<|end|>",
|
||||||
|
"<|start|>assistant<|message|>Done",
|
||||||
|
"<|end|>",
|
||||||
|
}
|
||||||
|
var thinkingPieces []string
|
||||||
|
var contentPieces []string
|
||||||
|
for _, in := range inputs {
|
||||||
|
content, thinking, tool := handler.AddContent(in, tp)
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
thinkingPieces = append(thinkingPieces, thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
contentPieces = append(contentPieces, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
|
||||||
|
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
|
||||||
|
}
|
||||||
|
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
|
||||||
|
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
inputs := []string{
|
||||||
|
"<|channel|>analysis<|message|>Think",
|
||||||
|
"<|end|>",
|
||||||
|
"<|start|>assistant<|message|>Answer",
|
||||||
|
"<|end|>",
|
||||||
|
}
|
||||||
|
var contentSb, thinkingSb strings.Builder
|
||||||
|
for _, in := range inputs {
|
||||||
|
content, thinking, tool := handler.AddContent(in, tp)
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
contentSb.WriteString(content)
|
||||||
|
thinkingSb.WriteString(thinking)
|
||||||
|
}
|
||||||
|
if contentSb.String() != "Answer" {
|
||||||
|
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
|
||||||
|
}
|
||||||
|
if thinkingSb.String() != "Think" {
|
||||||
|
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
inputs := []string{
|
||||||
|
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
|
||||||
|
}
|
||||||
|
for _, in := range inputs {
|
||||||
|
content, thinking, tool := handler.AddContent(in, tp)
|
||||||
|
if content != "" || thinking != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name, args := tp.Drain()
|
||||||
|
if name == nil || *name != "functions.calculate" {
|
||||||
|
t.Fatalf("unexpected tool name: %v", name)
|
||||||
|
}
|
||||||
|
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
||||||
|
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tool_call_across_chunks", func(t *testing.T) {
|
||||||
|
handler := NewHarmonyMessageHandler()
|
||||||
|
handler.HarmonyParser.AddImplicitStart()
|
||||||
|
tp := handler.CreateToolParser()
|
||||||
|
inputs := []string{
|
||||||
|
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
|
||||||
|
"2\"}",
|
||||||
|
"<|end|>",
|
||||||
|
}
|
||||||
|
for _, in := range inputs {
|
||||||
|
content, thinking, tool := handler.AddContent(in, tp)
|
||||||
|
if content != "" || thinking != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tool != "" {
|
||||||
|
tp.Add(tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name, args := tp.Drain()
|
||||||
|
if name == nil || *name != "functions.calculate" {
|
||||||
|
t.Fatalf("unexpected tool name: %v", name)
|
||||||
|
}
|
||||||
|
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
||||||
|
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionNameMapJSONRoundTrip(t *testing.T) {
|
||||||
|
m := NewFunctionNameMap()
|
||||||
|
gotConverted := m.ConvertAndAdd("get weather")
|
||||||
|
if gotConverted == "" {
|
||||||
|
t.Fatal("conversion returned empty")
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal: %v", err)
|
||||||
|
}
|
||||||
|
var m2 FunctionNameMap
|
||||||
|
if err := json.Unmarshal(b, &m2); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if m2.userToHarmony["get weather"] != gotConverted {
|
||||||
|
t.Fatalf("userToHarmony lost: got %q want %q", m2.userToHarmony["get weather"], gotConverted)
|
||||||
|
}
|
||||||
|
if m2.harmonyToUser[gotConverted] != "get weather" {
|
||||||
|
t.Fatalf("harmonyToUser lost: got %q want %q", m2.harmonyToUser[gotConverted], "get weather")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1347,7 +1347,9 @@ type CompletionRequest struct {
|
|||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
|
|
||||||
Grammar string // set before sending the request to the subprocess
|
Grammar string // set before sending the request to the subprocess
|
||||||
|
UseHarmony bool
|
||||||
|
PrefillContent *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoneReason represents the reason why a completion response is done
|
// DoneReason represents the reason why a completion response is done
|
||||||
@@ -1374,13 +1376,15 @@ func (d DoneReason) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
DoneReason DoneReason `json:"done_reason"`
|
Thinking string `json:"thinking"`
|
||||||
Done bool `json:"done"`
|
ToolCalls []api.ToolCall `json:"tool_calls"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count"`
|
DoneReason DoneReason `json:"done_reason"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
Done bool `json:"done"`
|
||||||
EvalCount int `json:"eval_count"`
|
PromptEvalCount int `json:"prompt_eval_count"`
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
|
EvalCount int `json:"eval_count"`
|
||||||
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
@@ -1498,7 +1502,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
case strings.TrimSpace(c.Content) == lastToken:
|
case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
|
||||||
tokenRepeat++
|
tokenRepeat++
|
||||||
default:
|
default:
|
||||||
lastToken = strings.TrimSpace(c.Content)
|
lastToken = strings.TrimSpace(c.Content)
|
||||||
@@ -1511,16 +1515,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Content != "" {
|
|
||||||
fn(CompletionResponse{
|
|
||||||
Content: c.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Done {
|
if c.Done {
|
||||||
fn(c)
|
fn(c)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
|
||||||
|
fn(c)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/harmony"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
@@ -773,6 +774,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||||
|
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||||
|
if req.UseHarmony {
|
||||||
|
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||||
|
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
|
||||||
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
|
}
|
||||||
|
|
||||||
if req.Options == nil {
|
if req.Options == nil {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
req.Options = &opts
|
req.Options = &opts
|
||||||
@@ -805,7 +814,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.TopP,
|
req.Options.TopP,
|
||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
@@ -856,6 +865,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): generalize grammar enablement on the fly for all thinking models
|
||||||
|
if harmonyMessageHandler == nil {
|
||||||
|
seq.sampler.SetGrammar(grammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
grammarSet := false
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
@@ -863,8 +878,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
|
var thinking string
|
||||||
|
if harmonyMessageHandler != nil {
|
||||||
|
var toolContent string
|
||||||
|
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
|
||||||
|
harmonyToolParser.Add(toolContent)
|
||||||
|
if harmonyMessageHandler.HarmonyParser.ConstrainAllowed && !grammarSet {
|
||||||
|
seq.sampler.SetGrammar(grammar)
|
||||||
|
grammarSet = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
Content: content,
|
Content: content,
|
||||||
|
Thinking: thinking,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
@@ -873,7 +900,29 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
} else {
|
} else {
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
if harmonyMessageHandler != nil {
|
||||||
|
// these tools still need to be transformed to the original function name
|
||||||
|
toolName, toolContent := harmonyToolParser.Drain()
|
||||||
|
if toolName != nil {
|
||||||
|
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||||
|
var args api.ToolCallFunctionArguments
|
||||||
|
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
|
||||||
|
close(seq.quit)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: *toolName,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
|
ToolCalls: toolCalls,
|
||||||
Done: true,
|
Done: true,
|
||||||
DoneReason: seq.doneReason,
|
DoneReason: seq.doneReason,
|
||||||
PromptEvalCount: seq.numPromptInputs,
|
PromptEvalCount: seq.numPromptInputs,
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ type Sampler struct {
|
|||||||
grammar *GrammarSampler
|
grammar *GrammarSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Sampler) SetGrammar(grammar *GrammarSampler) {
|
||||||
|
s.grammar = grammar
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
return -1, errors.New("sample: no logits provided to sample")
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
|
|||||||
138
server/routes.go
138
server/routes.go
@@ -46,18 +46,6 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
func shouldUseHarmony(model *Model) bool {
|
|
||||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
|
||||||
// heuristic to check whether the template expects to be parsed via harmony:
|
|
||||||
// search for harmony tags that are nearly always used
|
|
||||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func experimentEnabled(name string) bool {
|
func experimentEnabled(name string) bool {
|
||||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||||
}
|
}
|
||||||
@@ -207,13 +195,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(m) && !req.Raw
|
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
|
||||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
var functionNameMap *harmony.FunctionNameMap
|
||||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
functionNameMap = harmony.NewFunctionNameMap()
|
||||||
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
|
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for gptoss models
|
// Validate Think value: string values currently only allowed for gptoss models
|
||||||
@@ -357,16 +343,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
|
UseHarmony: useHarmony,
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Response: cr.Content,
|
Response: cr.Content,
|
||||||
Done: cr.Done,
|
Done: cr.Done,
|
||||||
|
Thinking: cr.Thinking,
|
||||||
|
ToolCalls: cr.ToolCalls,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: cr.PromptEvalCount,
|
PromptEvalCount: cr.PromptEvalCount,
|
||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
@@ -375,12 +364,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Done {
|
||||||
|
res.DoneReason = cr.DoneReason.String()
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
}
|
||||||
|
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
|
for i, tool := range res.ToolCalls {
|
||||||
res.Response = content
|
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
||||||
res.Thinking = thinking
|
}
|
||||||
harmonyToolParser.Add(toolContent)
|
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
|
||||||
} else if thinkingState != nil {
|
ch <- res
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if thinkingState != nil {
|
||||||
thinking, content := thinkingState.AddContent(cr.Content)
|
thinking, content := thinkingState.AddContent(cr.Content)
|
||||||
res.Thinking = thinking
|
res.Thinking = thinking
|
||||||
res.Response = content
|
res.Response = content
|
||||||
@@ -391,30 +390,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cr.Done {
|
if cr.Done {
|
||||||
if useHarmony {
|
|
||||||
toolName, toolContent := harmonyToolParser.Drain()
|
|
||||||
if toolName != nil {
|
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
|
||||||
var args api.ToolCallFunctionArguments
|
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
|
||||||
ch <- gin.H{"error": errStr}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: *toolName,
|
|
||||||
Arguments: args,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res.DoneReason = cr.DoneReason.String()
|
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
|
||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1616,27 +1591,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
msgs = filterThinkTags(msgs, m)
|
msgs = filterThinkTags(msgs, m)
|
||||||
|
|
||||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
|
||||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(m)
|
|
||||||
|
|
||||||
processedTools := req.Tools
|
processedTools := req.Tools
|
||||||
|
var functionNameMap *harmony.FunctionNameMap
|
||||||
|
var prefillContentOrThinking *bool
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
functionNameMap = harmony.NewFunctionNameMap()
|
||||||
var lastMessage *api.Message
|
var lastMessage *api.Message
|
||||||
if len(msgs) > 0 {
|
if len(msgs) > 0 {
|
||||||
lastMessage = &msgs[len(msgs)-1]
|
lastMessage = &msgs[len(msgs)-1]
|
||||||
}
|
}
|
||||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
|
||||||
|
|
||||||
|
// prefill content or thinking flag if the last message is an assistant message
|
||||||
|
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||||
|
if lastMessage.Content != "" {
|
||||||
|
trueVal := true
|
||||||
|
// true sets content to be prefilled
|
||||||
|
prefillContentOrThinking = &trueVal
|
||||||
|
} else if lastMessage.Thinking != "" {
|
||||||
|
// false sets thinking to be prefilled
|
||||||
|
falseVal := false
|
||||||
|
prefillContentOrThinking = &falseVal
|
||||||
|
}
|
||||||
|
}
|
||||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||||
// renamed to be valid Harmony function names.
|
// renamed to be valid Harmony function names.
|
||||||
processedTools = make([]api.Tool, len(req.Tools))
|
processedTools = make([]api.Tool, len(req.Tools))
|
||||||
copy(processedTools, req.Tools)
|
copy(processedTools, req.Tools)
|
||||||
for i, tool := range processedTools {
|
for i, tool := range processedTools {
|
||||||
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1689,15 +1673,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
|
UseHarmony: useHarmony,
|
||||||
|
PrefillContent: prefillContentOrThinking,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
res := api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
|
||||||
Done: r.Done,
|
Done: r.Done,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: r.PromptEvalCount,
|
PromptEvalCount: r.PromptEvalCount,
|
||||||
@@ -1713,31 +1699,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
|
|
||||||
res.Message.Content = content
|
|
||||||
res.Message.Thinking = thinking
|
|
||||||
harmonyToolParser.Add(toolContent)
|
|
||||||
|
|
||||||
if r.Done {
|
|
||||||
toolName, toolContent := harmonyToolParser.Drain()
|
|
||||||
if toolName != nil {
|
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
|
||||||
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
|
||||||
var args api.ToolCallFunctionArguments
|
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
|
||||||
ch <- gin.H{"error": errStr}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// only send messages with meaningful content (empty messages confuse clients)
|
// only send messages with meaningful content (empty messages confuse clients)
|
||||||
|
for i, tool := range res.Message.ToolCalls {
|
||||||
|
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
||||||
|
}
|
||||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
||||||
ch <- res
|
ch <- res
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -118,7 +117,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "content streams as it arrives",
|
name: "content streams as it arrives",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
|
input: llm.CompletionResponse{Content: "Hello", Done: false},
|
||||||
wantContent: "Hello",
|
wantContent: "Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -126,7 +125,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
wantContent: ", world",
|
wantContent: ", world",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "!",
|
wantContent: "!",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -135,20 +134,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "thinking streams separately from content",
|
name: "thinking streams separately from content",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
|
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
|
||||||
wantThinking: "Thinking...",
|
wantThinking: "Thinking...",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
|
input: llm.CompletionResponse{Content: "Answer", Done: false},
|
||||||
// No output expected - just closes the analysis message and resets state to normal
|
wantContent: "Answer",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
|
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "Answer", // After message end, state is reset to normal
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
|
||||||
// No output expected - just closes the assistant message
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -156,24 +150,16 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "partial tags buffer until complete",
|
name: "partial tags buffer until complete",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|chan", Done: false},
|
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
|
||||||
// No output - partial tag
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
|
|
||||||
// No output - still building tags
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
|
|
||||||
wantThinking: "Deep ",
|
wantThinking: "Deep ",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
|
input: llm.CompletionResponse{Thinking: "thought", Done: false},
|
||||||
wantThinking: "thought",
|
wantThinking: "thought",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "Done", // After message end, state is reset to normal
|
wantContent: "Done",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -181,7 +167,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "simple assistant after analysis",
|
name: "simple assistant after analysis",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "Answer",
|
wantContent: "Answer",
|
||||||
wantThinking: "Think",
|
wantThinking: "Think",
|
||||||
},
|
},
|
||||||
@@ -191,7 +177,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "tool call parsed and returned correctly",
|
name: "tool call parsed and returned correctly",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
wantContent: "The weather is sunny",
|
wantContent: "The weather is sunny",
|
||||||
wantToolCalls: []api.ToolCall{
|
wantToolCalls: []api.ToolCall{
|
||||||
{
|
{
|
||||||
@@ -210,15 +196,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
name: "tool call with streaming JSON across chunks",
|
name: "tool call with streaming JSON across chunks",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
|
input: llm.CompletionResponse{Done: false},
|
||||||
// No output yet - incomplete JSON
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
|
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
|
||||||
// Still no output - incomplete JSON
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: llm.CompletionResponse{Content: "2\"}", Done: true},
|
|
||||||
wantToolCalls: []api.ToolCall{
|
wantToolCalls: []api.ToolCall{
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
@@ -400,9 +381,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
mockResponses := []llm.CompletionResponse{
|
mockResponses := []llm.CompletionResponse{
|
||||||
{Content: "<|message|>First ", Done: false},
|
{Content: "First ", Done: false},
|
||||||
{Content: "chunk ", Done: false},
|
{Content: "chunk ", Done: false},
|
||||||
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock := mockRunner{
|
mock := mockRunner{
|
||||||
@@ -507,189 +488,3 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||||||
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChatHarmonyParserStreaming(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
type expectedChunk struct {
|
|
||||||
afterResponse int // Which mock response this chunk should appear after
|
|
||||||
content string // Expected content in this chunk
|
|
||||||
thinking string // Expected thinking in this chunk
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
mockResponses []llm.CompletionResponse
|
|
||||||
expectedChunks []expectedChunk
|
|
||||||
wantContent string
|
|
||||||
wantThinking string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "simple message without thinking",
|
|
||||||
mockResponses: []llm.CompletionResponse{
|
|
||||||
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
|
|
||||||
{Content: "how can I help?", Done: false},
|
|
||||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
|
||||||
},
|
|
||||||
expectedChunks: []expectedChunk{
|
|
||||||
{afterResponse: 1, content: "Hello, "},
|
|
||||||
{afterResponse: 2, content: "how can I help?"},
|
|
||||||
},
|
|
||||||
wantContent: "Hello, how can I help?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "message with analysis channel for thinking",
|
|
||||||
mockResponses: []llm.CompletionResponse{
|
|
||||||
{Content: "<|channel|>analysis<|message|>", Done: false},
|
|
||||||
{Content: "Let me think ", Done: false},
|
|
||||||
{Content: "about this problem...", Done: false},
|
|
||||||
{Content: "<|end|>", Done: false},
|
|
||||||
{Content: "<|start|>assistant<|message|>", Done: false},
|
|
||||||
{Content: "The answer ", Done: false},
|
|
||||||
{Content: "is 42", Done: false},
|
|
||||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
|
||||||
},
|
|
||||||
expectedChunks: []expectedChunk{
|
|
||||||
{afterResponse: 2, thinking: "Let me think "},
|
|
||||||
{afterResponse: 3, thinking: "about this problem..."},
|
|
||||||
{afterResponse: 6, content: "The answer "},
|
|
||||||
{afterResponse: 7, content: "is 42"},
|
|
||||||
},
|
|
||||||
wantContent: "The answer is 42",
|
|
||||||
wantThinking: "Let me think about this problem...",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "streaming with partial tags across boundaries",
|
|
||||||
mockResponses: []llm.CompletionResponse{
|
|
||||||
{Content: "<|chan", Done: false},
|
|
||||||
{Content: "nel|>analy", Done: false},
|
|
||||||
{Content: "sis<|mess", Done: false},
|
|
||||||
{Content: "age|>Think", Done: false},
|
|
||||||
{Content: "ing deeply...<|end|>", Done: false},
|
|
||||||
{Content: "<|start|>assi", Done: false},
|
|
||||||
{Content: "stant<|message|>Result ", Done: false},
|
|
||||||
{Content: "computed<|e", Done: false},
|
|
||||||
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
|
|
||||||
},
|
|
||||||
expectedChunks: []expectedChunk{
|
|
||||||
{afterResponse: 4, thinking: "Think"},
|
|
||||||
{afterResponse: 5, thinking: "ing deeply..."},
|
|
||||||
{afterResponse: 7, content: "Result "},
|
|
||||||
{afterResponse: 8, content: "computed"},
|
|
||||||
},
|
|
||||||
wantContent: "Result computed",
|
|
||||||
wantThinking: "Thinking deeply...",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
// Channel to synchronize mock responses with chunk verification
|
|
||||||
responsesSent := make(chan int, len(tc.mockResponses))
|
|
||||||
|
|
||||||
mock := mockRunner{
|
|
||||||
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
|
||||||
// Send mock responses one at a time, notifying when each is sent
|
|
||||||
for i, resp := range tc.mockResponses {
|
|
||||||
fn(resp)
|
|
||||||
responsesSent <- i + 1
|
|
||||||
}
|
|
||||||
close(responsesSent)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
s := Server{
|
|
||||||
sched: &Scheduler{
|
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
|
||||||
expiredCh: make(chan *runnerRef, 1),
|
|
||||||
unloadedCh: make(chan any, 1),
|
|
||||||
loaded: make(map[string]*runnerRef),
|
|
||||||
newServerFn: newMockServer(&mock),
|
|
||||||
getGpuFn: discover.GetGPUInfo,
|
|
||||||
getCpuFn: discover.GetCPUInfo,
|
|
||||||
reschedDelay: 250 * time.Millisecond,
|
|
||||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
|
||||||
req.successCh <- &runnerRef{
|
|
||||||
llama: &mock,
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.sched.Run(t.Context())
|
|
||||||
|
|
||||||
// Create a minimal model
|
|
||||||
_, digest := createHarmonyTestModel(t)
|
|
||||||
|
|
||||||
// Create model with passthrough template
|
|
||||||
stream := false
|
|
||||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
||||||
Model: "harmony-test",
|
|
||||||
Files: map[string]string{"file.gguf": digest},
|
|
||||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
|
||||||
Stream: &stream,
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("failed to create model: %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test chat endpoint with streaming
|
|
||||||
streamTrue := true
|
|
||||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
||||||
Model: "harmony-test",
|
|
||||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
|
||||||
Stream: &streamTrue,
|
|
||||||
Tools: getTestTools(),
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse streaming response
|
|
||||||
var chunks []api.ChatResponse
|
|
||||||
var content, thinking strings.Builder
|
|
||||||
|
|
||||||
decoder := json.NewDecoder(w.Body)
|
|
||||||
for decoder.More() {
|
|
||||||
var chunk api.ChatResponse
|
|
||||||
if err := decoder.Decode(&chunk); err != nil {
|
|
||||||
t.Fatalf("failed to decode chunk: %v", err)
|
|
||||||
}
|
|
||||||
chunks = append(chunks, chunk)
|
|
||||||
|
|
||||||
// Accumulate content and thinking from each chunk
|
|
||||||
content.WriteString(chunk.Message.Content)
|
|
||||||
thinking.WriteString(chunk.Message.Thinking)
|
|
||||||
|
|
||||||
// Debug output
|
|
||||||
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify we got streaming chunks
|
|
||||||
if len(chunks) == 0 {
|
|
||||||
t.Fatal("expected streaming chunks, got none")
|
|
||||||
}
|
|
||||||
|
|
||||||
gotContent := content.String()
|
|
||||||
gotThinking := thinking.String()
|
|
||||||
|
|
||||||
if gotContent != tc.wantContent {
|
|
||||||
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
|
|
||||||
}
|
|
||||||
if gotThinking != tc.wantThinking {
|
|
||||||
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify last chunk has done=true
|
|
||||||
lastChunk := chunks[len(chunks)-1]
|
|
||||||
if !lastChunk.Done {
|
|
||||||
t.Error("expected last chunk to have done=true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user