mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
8 Commits
implement-
...
parth/move
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f30d01801d | ||
|
|
b08c7dad0a | ||
|
|
bc5ab5784b | ||
|
|
92a99e67c7 | ||
|
|
05cebf1f21 | ||
|
|
51a400ff0f | ||
|
|
a865b50d9a | ||
|
|
31f64183dc |
@@ -1,30 +1,26 @@
|
||||
package server
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
type harmonyParserState int
|
||||
|
||||
const (
|
||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||
harmonyParserState_ParsingHeader
|
||||
harmonyParserState_ParsingContent
|
||||
)
|
||||
|
||||
func shouldUseHarmony(model Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||
if template.Contains("<|start|>") && template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -32,6 +28,12 @@ func shouldUseHarmony(model Model) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||
harmonyParserState_ParsingHeader
|
||||
harmonyParserState_ParsingContent
|
||||
)
|
||||
|
||||
func (s harmonyParserState) String() string {
|
||||
switch s {
|
||||
// we're looking for the message start tag
|
||||
@@ -89,17 +91,18 @@ func (s *HarmonyParser) AddImplicitStart() {
|
||||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
// handle prefilling conditions
|
||||
if lastMessage.Content != "" {
|
||||
// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
|
||||
if prefillContentOrThinking != nil {
|
||||
if *prefillContentOrThinking {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||
return
|
||||
} else if lastMessage.Thinking != "" {
|
||||
} else {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.AddImplicitStart()
|
||||
}
|
||||
|
||||
@@ -277,20 +280,20 @@ const (
|
||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||
type HarmonyMessageHandler struct {
|
||||
state harmonyMessageState
|
||||
harmonyParser *HarmonyParser
|
||||
functionNameMap *FunctionNameMap
|
||||
HarmonyParser *HarmonyParser
|
||||
FunctionNameMap *FunctionNameMap
|
||||
}
|
||||
|
||||
// NewHarmonyMessageHandler creates a new message handler
|
||||
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||
return &HarmonyMessageHandler{
|
||||
state: harmonyMessageState_Normal,
|
||||
harmonyParser: &HarmonyParser{
|
||||
HarmonyParser: &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
},
|
||||
functionNameMap: NewFunctionNameMap(),
|
||||
FunctionNameMap: NewFunctionNameMap(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,7 +304,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
||||
thinkingSb := strings.Builder{}
|
||||
toolContentSb := strings.Builder{}
|
||||
|
||||
events := h.harmonyParser.AddContent(content)
|
||||
events := h.HarmonyParser.AddContent(content)
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case HarmonyEventHeaderComplete:
|
||||
@@ -391,6 +394,38 @@ type FunctionNameMap struct {
|
||||
harmonyToUser map[string]string
|
||||
}
|
||||
|
||||
func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
|
||||
// necessary to avoid exposing map internals
|
||||
type alias struct {
|
||||
UserToHarmony map[string]string `json:"userToHarmony"`
|
||||
HarmonyToUser map[string]string `json:"harmonyToUser"`
|
||||
}
|
||||
return json.Marshal(alias{
|
||||
UserToHarmony: m.userToHarmony,
|
||||
HarmonyToUser: m.harmonyToUser,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
|
||||
type alias struct {
|
||||
UserToHarmony map[string]string `json:"userToHarmony"`
|
||||
HarmonyToUser map[string]string `json:"harmonyToUser"`
|
||||
}
|
||||
var a alias
|
||||
if err := json.Unmarshal(b, &a); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.userToHarmony == nil {
|
||||
m.userToHarmony = make(map[string]string)
|
||||
}
|
||||
if m.harmonyToUser == nil {
|
||||
m.harmonyToUser = make(map[string]string)
|
||||
}
|
||||
maps.Copy(m.userToHarmony, a.UserToHarmony)
|
||||
maps.Copy(m.harmonyToUser, a.HarmonyToUser)
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFunctionNameMap() *FunctionNameMap {
|
||||
return &FunctionNameMap{
|
||||
userToHarmony: make(map[string]string),
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -2,10 +2,13 @@
|
||||
|
||||
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
||||
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
|
||||
|
||||
|
||||
The integration tests have 2 modes of operating.
|
||||
|
||||
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple requests with prior context and ensure the response is coherant and expected
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := GenerateRequests()
|
||||
@@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
req[k].Messages = append(req[k].Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: "tell me more!"},
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||
return
|
||||
|
||||
@@ -567,6 +567,76 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
}
|
||||
}
|
||||
|
||||
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||
genReqs, results := GenerateRequests()
|
||||
reqs := make([]api.ChatRequest, len(genReqs))
|
||||
for i := range reqs {
|
||||
reqs[i].Model = genReqs[i].Model
|
||||
reqs[i].Stream = genReqs[i].Stream
|
||||
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||
reqs[i].Messages = []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: genReqs[i].Prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
return reqs, results
|
||||
}
|
||||
|
||||
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
role := "assistant"
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// fmt.Print(".")
|
||||
role = response.Message.Role
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||
return nil
|
||||
}
|
||||
require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages)
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
// TODO use info API in the future
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -1331,7 +1332,9 @@ type CompletionRequest struct {
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
FunctionNameMap *harmony.FunctionNameMap
|
||||
PrefillContent *bool
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1358,13 +1361,15 @@ func (d DoneReason) String() string {
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
ToolCalls []api.ToolCall `json:"tool_calls"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
@@ -1482,7 +1487,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
@@ -1495,16 +1500,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Done {
|
||||
fn(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
|
||||
fn(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -400,6 +400,8 @@ type Tensor interface {
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
|
||||
BackendSetFromIntSlice(s []int32)
|
||||
|
||||
Neg(ctx Context) Tensor
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
@@ -82,6 +82,7 @@ type Backend struct {
|
||||
// to the name that is used by the model definition
|
||||
tensorLoadTargets map[string][]string
|
||||
|
||||
schedMu sync.Mutex // Only one Compute can run at a time
|
||||
sched C.ggml_backend_sched_t
|
||||
schedBackends []C.ggml_backend_t
|
||||
schedBufts []C.ggml_backend_buffer_type_t
|
||||
@@ -769,6 +770,8 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
c.b.schedMu.Lock()
|
||||
defer c.b.schedMu.Unlock()
|
||||
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
||||
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
||||
}
|
||||
@@ -1037,6 +1040,12 @@ func (t *Tensor) Floats() (data []float32) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) BackendSetFromIntSlice(s []int32) {
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) DType() ml.DType {
|
||||
switch t.t._type {
|
||||
case C.GGML_TYPE_F32:
|
||||
|
||||
@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
|
||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize([]input.Input) ([]input.Input, error)
|
||||
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
@@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool {
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
||||
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
||||
if len(batch.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
return nil, nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
@@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t).Compute(t)
|
||||
ctx.Forward(t)
|
||||
|
||||
return t, nil
|
||||
return batch.Inputs, t, nil
|
||||
}
|
||||
|
||||
@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
input.Input{Token: 255999}, // "<start_of_image>""
|
||||
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 256000}, // <end_of_image>
|
||||
input.Input{Token: 108}, // "\n\n"
|
||||
&input.Input{Token: 256000}, // <end_of_image>
|
||||
&input.Input{Token: 108}, // "\n\n"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,16 +134,16 @@ type separator struct {
|
||||
y bool
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
var imageInputs []input.Input
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
|
||||
var imageInputs []*input.Input
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
|
||||
|
||||
for i, mm := range inp.Multimodal {
|
||||
patchesPerChunk := mm.Tensor.Dim(1)
|
||||
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
if i < len(inp.Multimodal)-1 {
|
||||
separator := mm.Data.(*separator)
|
||||
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
|
||||
if separator.x {
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||
}
|
||||
if separator.y {
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||
}
|
||||
} else {
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
||||
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
||||
// that can be processed together.
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
for i, row := range inp.Multimodal {
|
||||
// [IMG]
|
||||
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||
if i == len(inp.Multimodal)-1 {
|
||||
// [IMG_END]
|
||||
result = append(result, input.Input{Token: 13})
|
||||
result = append(result, &input.Input{Token: 13})
|
||||
} else {
|
||||
// [IMG_BREAK]
|
||||
result = append(result, input.Input{Token: 12})
|
||||
result = append(result, &input.Input{Token: 12})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
for i := range inputs {
|
||||
if inputs[i].Multimodal != nil {
|
||||
inputs[i].Token = 128256 // <|image|>
|
||||
|
||||
@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
var (
|
||||
imageToken int32 = 151655
|
||||
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||
}
|
||||
for i := range pre {
|
||||
result = append(result, input.Input{Token: pre[i]})
|
||||
result = append(result, &input.Input{Token: pre[i]})
|
||||
}
|
||||
|
||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, input.Input{Token: visionStartToken})
|
||||
result = append(result, &input.Input{Token: visionStartToken})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
result = append(result, input.Input{
|
||||
result = append(result, &input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
|
||||
result = append(result, input.Input{Token: visionEndToken})
|
||||
result = append(result, &input.Input{Token: visionEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ type InputCacheSlot struct {
|
||||
Id int
|
||||
|
||||
// Inputs that are stored in the KV cache
|
||||
Inputs []input.Input
|
||||
Inputs []*input.Input
|
||||
|
||||
// is this cache actively being processed as part of a sequence?
|
||||
InUse bool
|
||||
@@ -95,7 +95,7 @@ type InputCacheSlot struct {
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
||||
return slot, prompt, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||
longest := int32(-1)
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
|
||||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
||||
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||
oldest := time.Now()
|
||||
var oldestSlot *InputCacheSlot
|
||||
|
||||
@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
||||
if longest > 0 && longestSlot != oldestSlot {
|
||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||
len(longestSlot.Inputs))
|
||||
oldestSlot.Inputs = make([]input.Input, longest)
|
||||
oldestSlot.Inputs = make([]*input.Input, longest)
|
||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||
if c.cache != nil {
|
||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||
@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
||||
return oldestSlot, longest, nil
|
||||
}
|
||||
|
||||
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
||||
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
|
||||
var count int32
|
||||
|
||||
for i := range a {
|
||||
@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input.Input
|
||||
Inputs []*input.Input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||
"id", slot.Id, "error", err)
|
||||
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Reset the cache
|
||||
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
||||
slot.Inputs = []input.Input{}
|
||||
slot.Inputs = []*input.Input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
|
||||
@@ -13,50 +13,50 @@ import (
|
||||
func TestCountCommon(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
t1 []input.Input
|
||||
t2 []input.Input
|
||||
t1 []*input.Input
|
||||
t2 []*input.Input
|
||||
expected int32
|
||||
}{
|
||||
{
|
||||
name: "Equal",
|
||||
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Prefix",
|
||||
t1: []input.Input{{Token: 1}},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []*input.Input{{Token: 1}},
|
||||
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Image Prefix",
|
||||
t1: []input.Input{{MultimodalHash: 1}},
|
||||
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
||||
t1: []*input.Input{{MultimodalHash: 1}},
|
||||
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Mixed",
|
||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
||||
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Mixed, Same Length",
|
||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
|
||||
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
t1: []input.Input{},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []*input.Input{},
|
||||
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Both Empty",
|
||||
t1: []input.Input{},
|
||||
t2: []input.Input{},
|
||||
t1: []*input.Input{},
|
||||
t2: []*input.Input{},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
prompt []*input.Input
|
||||
longest expected
|
||||
best expected
|
||||
}{
|
||||
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 1}},
|
||||
prompt: []*input.Input{{Token: 1}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 0, len: 0},
|
||||
},
|
||||
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
Inputs: []*input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
longest: expected{result: 1, len: 2},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 2}},
|
||||
prompt: []*input.Input{{Token: 2}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}},
|
||||
prompt: []*input.Input{{Token: 1}},
|
||||
longest: expected{result: 0, len: 1},
|
||||
best: expected{result: 1, len: 1},
|
||||
},
|
||||
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
Inputs: []*input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 2}, {Token: 3}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
Inputs: []*input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
longest: expected{result: 1, len: 1},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
prompt []*input.Input
|
||||
wantErr bool
|
||||
expectedSlotId int
|
||||
expectedPrompt int // expected length of remaining prompt
|
||||
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: true,
|
||||
expectedSlotId: -1,
|
||||
expectedPrompt: -1,
|
||||
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
inputs []input.Input
|
||||
inputs []*input.Input
|
||||
numKeep int32
|
||||
cacheErr bool
|
||||
wantErr any
|
||||
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||
{
|
||||
name: "Normal shift",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: false, // No error
|
||||
wantErr: nil,
|
||||
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||
{
|
||||
name: "Cache removal fails",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: true,
|
||||
wantErr: &ErrReprocessInputs{},
|
||||
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||
}
|
||||
slot := &InputCacheSlot{
|
||||
Id: 123,
|
||||
Inputs: make([]input.Input, len(tt.inputs)),
|
||||
Inputs: make([]*input.Input, len(tt.inputs)),
|
||||
}
|
||||
copy(slot.Inputs, tt.inputs)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -51,10 +53,10 @@ type Sequence struct {
|
||||
iBatch int
|
||||
|
||||
// prompt inputs left to evaluate
|
||||
inputs []input.Input
|
||||
inputs []*input.Input
|
||||
|
||||
// inputs that have been added to a batch but not yet submitted to Forward
|
||||
pendingInputs []input.Input
|
||||
pendingInputs []*input.Input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
@@ -86,6 +88,12 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if the sequence if finished and marked for removal on next pass
|
||||
finished bool
|
||||
|
||||
// True if we have to skip this sequence to shift the cache
|
||||
skipForShift bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -182,8 +190,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
|
||||
var inputs []input.Input
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
|
||||
var inputs []*input.Input
|
||||
var ctxs []ml.Context
|
||||
var mmStore multimodalStore
|
||||
|
||||
@@ -210,7 +218,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
inputs = append(inputs, input.Input{Token: t})
|
||||
inputs = append(inputs, &input.Input{Token: t})
|
||||
}
|
||||
|
||||
// image - decode and store
|
||||
@@ -243,7 +251,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
||||
|
||||
mmStore.addMultimodal(imageEmbeddings)
|
||||
|
||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
postTokenize = true
|
||||
}
|
||||
}
|
||||
@@ -259,6 +267,27 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
||||
return inputs, ctxs, mmStore, nil
|
||||
}
|
||||
|
||||
type batchState struct {
|
||||
id int
|
||||
ctx ml.Context
|
||||
modelInput ml.Tensor
|
||||
modelOutput ml.Tensor
|
||||
batchInputs []*input.Input
|
||||
batch input.Batch
|
||||
seqs []*Sequence // full set of seqs at the time this batch was initiated
|
||||
initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1)
|
||||
|
||||
// Signaled when this batches inputs are ready and compute can proceed
|
||||
inputsReadyCh chan struct{}
|
||||
|
||||
// Signaling when Compute is about to begin on this batch, and
|
||||
// seqs have been updated to prepare for the next batch
|
||||
computeStartedCh chan struct{}
|
||||
|
||||
// Signaled when this batches outputs are complete and the next batch can proceed
|
||||
outputsReadyCh chan struct{}
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// modelPath is the location of the model to be loaded
|
||||
modelPath string
|
||||
@@ -290,6 +319,16 @@ type Server struct {
|
||||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// Used to signal a hard failure during async processing which will panic the runner
|
||||
hardErrCh chan error
|
||||
|
||||
// A prior batch that's still being processed
|
||||
// only read or written by forwardBatch
|
||||
pendingBatch *batchState
|
||||
|
||||
// Simple counter used only for trace logging batches
|
||||
batchID int
|
||||
|
||||
// protects access to everything below this line
|
||||
// this is context state needed for decoding
|
||||
mu sync.Mutex
|
||||
@@ -350,45 +389,132 @@ func flushPending(seq *Sequence) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
// finish could be called multiple times since we prepare 1 batch ahead
|
||||
// and multiple scenarios can lead to finishing a sequence
|
||||
// ensure only the first finish called is processed
|
||||
if seq.finished {
|
||||
return
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
seq.finished = true
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
func (s *Server) removeFinishedSequence(seqIndex int) {
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
// track batch state between forwardBatch, computeBatch and predictForwardBatch
|
||||
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
var bs *batchState
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-s.hardErrCh:
|
||||
panic(err)
|
||||
default:
|
||||
err := s.processBatch()
|
||||
var err error
|
||||
bs, err = s.forwardBatch()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if bs == nil {
|
||||
continue
|
||||
}
|
||||
go s.computeBatch(bs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processBatch() error {
|
||||
// forwardBatch will calculate a batch.
|
||||
func (s *Server) forwardBatch() (*batchState, error) {
|
||||
inputsReady := false
|
||||
var inputsReadyCh chan struct{}
|
||||
|
||||
// If we have a pending batch still processing, wait until Compute has started
|
||||
// before setting up the next batch so the seqs inputs are ready to receive their
|
||||
// token values and we get the correct input pointers for the batchInputs
|
||||
if s.pendingBatch != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
|
||||
<-s.pendingBatch.computeStartedCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
|
||||
inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||
inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
|
||||
inputsReadyCh = make(chan struct{}, 1)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
for s.allNil() {
|
||||
s.cond.Wait() // Wait until an item is added
|
||||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
// If new sequences have been added with an active batch we delay preparing the next batch
|
||||
// until Compute has finished
|
||||
if s.pendingBatch != nil {
|
||||
for seqIdx := range s.seqs {
|
||||
if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
|
||||
s.mu.Unlock() // release the lock so computeBatch can finish up
|
||||
<-s.pendingBatch.outputsReadyCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
|
||||
s.mu.Lock()
|
||||
inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clear pending Batch - we'll set it if we have a batch with any inputs
|
||||
s.pendingBatch = nil
|
||||
|
||||
var batchInputs []int32
|
||||
// Remove any finished sequences before recording the active set of seqs in the batch
|
||||
for seqIdx := range s.seqs {
|
||||
seq := s.seqs[seqIdx]
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
if seq.finished {
|
||||
s.removeFinishedSequence(seqIdx)
|
||||
continue
|
||||
}
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.finishSequence(seqIdx, llm.DoneReasonLength)
|
||||
s.removeFinishedSequence(seqIdx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// next batch
|
||||
nb := &batchState{
|
||||
id: s.batchID,
|
||||
initSeqIdx: s.nextSeq - 1,
|
||||
seqs: make([]*Sequence, len(s.seqs)),
|
||||
inputsReadyCh: inputsReadyCh,
|
||||
computeStartedCh: make(chan struct{}, 1),
|
||||
outputsReadyCh: make(chan struct{}, 1),
|
||||
}
|
||||
ctx := s.model.Backend().NewContext()
|
||||
nb.ctx = ctx
|
||||
|
||||
// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
|
||||
copy(nb.seqs, s.seqs)
|
||||
|
||||
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||
var batchInputs []*input.Input
|
||||
var batch input.Batch
|
||||
|
||||
resumeSeq := -1
|
||||
@@ -396,20 +522,13 @@ func (s *Server) processBatch() error {
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
}
|
||||
|
||||
batchSize := s.batchSize
|
||||
@@ -449,18 +568,21 @@ func (s *Server) processBatch() error {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
seq.skipForShift = true // cleared in computeBatch below for the next batch
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
ctx.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batchInputs = append(batchInputs, inp.Token)
|
||||
batchInputs = append(batchInputs, seq.inputs[i])
|
||||
if inp.Multimodal != nil {
|
||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
|
||||
if err != nil {
|
||||
return err
|
||||
ctx.Close()
|
||||
return nil, err
|
||||
}
|
||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
||||
}
|
||||
@@ -468,10 +590,13 @@ func (s *Server) processBatch() error {
|
||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
// TODO BUG HERE!!!
|
||||
// Somehow sometimes iBatch isn't set correctly
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
}
|
||||
|
||||
@@ -485,36 +610,138 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
return nil
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
|
||||
ctx.Close()
|
||||
return nil, nil
|
||||
}
|
||||
s.batchID++
|
||||
|
||||
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
||||
var err error
|
||||
// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
|
||||
nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
ctx.Close()
|
||||
return nil, fmt.Errorf("failed to build graph: %w", err)
|
||||
}
|
||||
nb.batchInputs = batchInputs
|
||||
nb.batch = batch
|
||||
|
||||
// computeBatch will close the context in the batch upon completion
|
||||
s.pendingBatch = nb
|
||||
|
||||
if inputsReady {
|
||||
nb.inputsReadyCh <- struct{}{}
|
||||
}
|
||||
|
||||
logits := modelOutput.Floats()
|
||||
return nb, nil
|
||||
}
|
||||
|
||||
// Async processing of the next batch
|
||||
func (s *Server) computeBatch(bs *batchState) {
|
||||
if bs == nil || bs.ctx == nil {
|
||||
// Nothing to compute
|
||||
return
|
||||
}
|
||||
defer bs.ctx.Close()
|
||||
|
||||
// Wait until inputs are ready
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
|
||||
<-bs.inputsReadyCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)
|
||||
|
||||
// Once we complete, signal the next batch of inputs are ready
|
||||
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
|
||||
defer func() {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
|
||||
bs.outputsReadyCh <- struct{}{}
|
||||
}()
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
// Gather the actual input token values now that they're ready
|
||||
batchInputs := make([]int32, len(bs.batchInputs))
|
||||
for i := range batchInputs {
|
||||
batchInputs[i] = bs.batchInputs[i].Token
|
||||
}
|
||||
|
||||
// TODO the following logic could be run in a go routine to possibly speed up getting to Compute
|
||||
|
||||
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
|
||||
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
|
||||
// decoded tokens.
|
||||
promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip
|
||||
nextBatchTokens := make([]*input.Input, len(s.seqs))
|
||||
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
|
||||
for i, seq := range s.seqs {
|
||||
iBatches[i] = -1
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
// Skip over any newly added sequences
|
||||
if bs.seqs[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// After calling Forward, pending inputs are now in the cache
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
seq.pendingInputs = []input.Input{}
|
||||
seq.pendingInputs = []*input.Input{}
|
||||
}
|
||||
|
||||
// don't sample prompt processing
|
||||
if len(seq.inputs) != 0 {
|
||||
if !s.cache.enabled {
|
||||
return errors.New("caching disabled but unable to fit entire input in a batch")
|
||||
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
|
||||
return
|
||||
}
|
||||
// Record so we can skip during Decode
|
||||
promptProcessing[i] = true
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||
seq.inputs = []*input.Input{nextToken}
|
||||
nextBatchTokens[i] = nextToken
|
||||
iBatches[i] = seq.iBatch
|
||||
}
|
||||
|
||||
// At this point the seqs are ready for forwardBatch to move forward so unblock
|
||||
s.mu.Unlock()
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
|
||||
bs.computeStartedCh <- struct{}{}
|
||||
|
||||
bs.modelInput.BackendSetFromIntSlice(batchInputs)
|
||||
bs.ctx.Compute(bs.modelOutput)
|
||||
logits := bs.modelOutput.Floats()
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
// Skip over any newly added sequences
|
||||
if bs.seqs[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Detect if the sequence we're processing has already been completed and replaced
|
||||
// with a new sequence
|
||||
if seq != bs.seqs[i] {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
// don't sample prompt processing
|
||||
if promptProcessing[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
if seq.numPredicted == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
}
|
||||
@@ -522,35 +749,46 @@ func (s *Server) processBatch() error {
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
if seq.embeddingOnly {
|
||||
// TODO(jessegross): Embedding support
|
||||
slog.Warn("generation of embedding outputs not yet supported")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
|
||||
s.finishSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(logits) / len(batch.Outputs)
|
||||
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
vocabSize := len(logits) / len(bs.batch.Outputs)
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
|
||||
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
|
||||
s.finishSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
return err
|
||||
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
seq.inputs = []input.Input{{Token: token}}
|
||||
if nextBatchTokens[i] == nil {
|
||||
slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
|
||||
s.hardErrCh <- fmt.Errorf("expected a single token during decode")
|
||||
return
|
||||
}
|
||||
|
||||
// fill in the final selected token value to replace the placeholder in the next batch
|
||||
// nextBatchTokensWritten++
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
@@ -575,9 +813,10 @@ func (s *Server) processBatch() error {
|
||||
if tokenTruncated || origLen == newLen {
|
||||
tokenLen--
|
||||
}
|
||||
|
||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
s.finishSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -590,11 +829,9 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
|
||||
if !flushPending(seq) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
s.finishSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -604,6 +841,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||
if req.FunctionNameMap != nil {
|
||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.FunctionNameMap = req.FunctionNameMap
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
@@ -694,8 +940,16 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
var thinking string
|
||||
if harmonyMessageHandler != nil {
|
||||
var toolContent string
|
||||
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
|
||||
harmonyToolParser.Add(toolContent)
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
Content: content,
|
||||
Thinking: thinking,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
@@ -704,7 +958,29 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher.Flush()
|
||||
} else {
|
||||
var toolCalls []api.ToolCall
|
||||
if harmonyMessageHandler != nil {
|
||||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: *toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
ToolCalls: toolCalls,
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
@@ -736,7 +1012,10 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||
defer ctx.Close()
|
||||
|
||||
var err error
|
||||
inputs := make([]input.Input, s.batchSize)
|
||||
inputs := make([]*input.Input, s.batchSize)
|
||||
for i := range inputs {
|
||||
inputs[i] = &input.Input{}
|
||||
}
|
||||
mmStore := newMultimodalStore()
|
||||
|
||||
// Multimodal strategy:
|
||||
@@ -778,8 +1057,11 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||
}
|
||||
|
||||
if len(inputs) < s.batchSize {
|
||||
newInputs := make([]input.Input, s.batchSize)
|
||||
newInputs := make([]*input.Input, s.batchSize)
|
||||
copy(newInputs, inputs)
|
||||
for i := len(inputs); i < s.batchSize; i++ {
|
||||
newInputs[i] = &input.Input{}
|
||||
}
|
||||
inputs = newInputs
|
||||
}
|
||||
}
|
||||
@@ -842,6 +1124,7 @@ func (s *Server) allocModel(
|
||||
// Convert memory allocation panics to errors
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
debug.PrintStack()
|
||||
if err, ok := r.(error); ok {
|
||||
panicErr = err
|
||||
} else {
|
||||
@@ -1011,6 +1294,7 @@ func Execute(args []string) error {
|
||||
server := &Server{
|
||||
modelPath: *mpath,
|
||||
status: llm.ServerStatusLaunched,
|
||||
hardErrCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -194,14 +195,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
useHarmony := shouldUseHarmony(*m) && !req.Raw
|
||||
var harmonyMessageHandler *HarmonyMessageHandler
|
||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.harmonyParser.AddImplicitStart()
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
}
|
||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
|
||||
|
||||
// Validate Think value: string values currently only allowed for gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
||||
@@ -362,12 +356,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
|
||||
res.Response = content
|
||||
res.Thinking = thinking
|
||||
harmonyToolParser.Add(toolContent)
|
||||
} else if thinkingState != nil {
|
||||
if !useHarmony && thinkingState != nil {
|
||||
thinking, content := thinkingState.AddContent(cr.Content)
|
||||
res.Thinking = thinking
|
||||
res.Response = content
|
||||
@@ -378,26 +367,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
if useHarmony {
|
||||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||
ch <- gin.H{"error": errStr}
|
||||
return
|
||||
}
|
||||
|
||||
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: *toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
@@ -1603,27 +1572,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
var harmonyMessageHandler *HarmonyMessageHandler
|
||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
||||
|
||||
useHarmony := shouldUseHarmony(*m)
|
||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
|
||||
|
||||
processedTools := req.Tools
|
||||
var functionNameMap *harmony.FunctionNameMap
|
||||
var prefillContentOrThinking *bool
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
||||
functionNameMap = harmony.NewFunctionNameMap()
|
||||
var lastMessage *api.Message
|
||||
if len(msgs) > 0 {
|
||||
lastMessage = &msgs[len(msgs)-1]
|
||||
}
|
||||
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
|
||||
// prefill content or thinking flag if the last message is an assistant message
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
if lastMessage.Content != "" {
|
||||
trueVal := true
|
||||
// true sets content to be prefilled
|
||||
prefillContentOrThinking = &trueVal
|
||||
} else if lastMessage.Thinking != "" {
|
||||
// false sets thinking to be prefilled
|
||||
falseVal := false
|
||||
prefillContentOrThinking = &falseVal
|
||||
}
|
||||
}
|
||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||
// renamed to be valid Harmony function names.
|
||||
processedTools = make([]api.Tool, len(req.Tools))
|
||||
copy(processedTools, req.Tools)
|
||||
for i, tool := range processedTools {
|
||||
processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1672,15 +1650,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
FunctionNameMap: functionNameMap,
|
||||
PrefillContent: prefillContentOrThinking,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
@@ -1696,31 +1676,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
|
||||
res.Message.Content = content
|
||||
res.Message.Thinking = thinking
|
||||
harmonyToolParser.Add(toolContent)
|
||||
|
||||
if r.Done {
|
||||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
*toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName)
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||
ch <- gin.H{"error": errStr}
|
||||
return
|
||||
}
|
||||
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
|
||||
}
|
||||
}
|
||||
|
||||
// only send messages with meaningful content (empty messages confuse clients)
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
||||
ch <- res
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user