mirror of https://github.com/ollama/ollama.git (synced 2026-01-19 12:57:56 -05:00)

Compare commits: parth/decr ... parth/move (8 commits)

| SHA1 |
|---|
| f30d01801d |
| b08c7dad0a |
| bc5ab5784b |
| 92a99e67c7 |
| 05cebf1f21 |
| 51a400ff0f |
| a865b50d9a |
| 31f64183dc |
@@ -1,30 +1,26 @@
-package server
+package harmony

import (
    "context"
+    "encoding/json"
    "fmt"
    "log/slog"
+    "maps"
    "slices"
    "strings"
    "unicode"

-    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/logutil"
+    "github.com/ollama/ollama/template"
)

type harmonyParserState int

-const (
-    harmonyParserState_LookingForMessageStart harmonyParserState = iota
-    harmonyParserState_ParsingHeader
-    harmonyParserState_ParsingContent
-)
-
-func shouldUseHarmony(model Model) bool {
-    if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
+func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
+    if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
        // heuristic to check whether the template expects to be parsed via harmony:
        // search for harmony tags that are nearly always used
-        if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
+        if template.Contains("<|start|>") && template.Contains("<|end|>") {
            return true
        }
    }

@@ -32,6 +28,12 @@ func shouldUseHarmony(model Model) bool {
    return false
}

+const (
+    harmonyParserState_LookingForMessageStart harmonyParserState = iota
+    harmonyParserState_ParsingHeader
+    harmonyParserState_ParsingContent
+)
+
func (s harmonyParserState) String() string {
    switch s {
    // we're looking for the message start tag
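The two hunks above turn the unexported `shouldUseHarmony(model Model)` helper into an exported, dependency-free `harmony.ShouldUseHarmony`. A minimal sketch of the new call shape — the template text is made up, and it assumes the repository's `template.Parse(string)` helper behaves as it does elsewhere in the codebase:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/harmony"
	"github.com/ollama/ollama/template"
)

func main() {
	// Placeholder template; the heuristic only looks for <|start|> and <|end|> tags.
	tmpl, err := template.Parse("<|start|>{{ .Prompt }}<|end|>")
	if err != nil {
		panic(err)
	}

	// Callers now pass the model family string and template directly
	// instead of a server-side Model value.
	fmt.Println(harmony.ShouldUseHarmony("gpt-oss", tmpl))
}
```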
@@ -89,17 +91,18 @@ func (s *HarmonyParser) AddImplicitStart() {
    s.acc.WriteString("<|start|>assistant")
}

-func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
-    if lastMessage != nil && lastMessage.Role == "assistant" {
-        // handle prefilling conditions
-        if lastMessage.Content != "" {
+// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
+func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
+    if prefillContentOrThinking != nil {
+        if *prefillContentOrThinking {
            s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
            return
-        } else if lastMessage.Thinking != "" {
+        } else {
            s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
            return
        }
    }

    s.AddImplicitStart()
}

@@ -277,20 +280,20 @@ const (
// This is a higher level interface that maps harmony concepts into ollama concepts
type HarmonyMessageHandler struct {
    state harmonyMessageState
-    harmonyParser *HarmonyParser
-    functionNameMap *FunctionNameMap
+    HarmonyParser *HarmonyParser
+    FunctionNameMap *FunctionNameMap
}

// NewHarmonyMessageHandler creates a new message handler
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
    return &HarmonyMessageHandler{
        state: harmonyMessageState_Normal,
-        harmonyParser: &HarmonyParser{
+        HarmonyParser: &HarmonyParser{
            MessageStartTag: "<|start|>",
            MessageEndTag: "<|end|>",
            HeaderEndTag: "<|message|>",
        },
-        functionNameMap: NewFunctionNameMap(),
+        FunctionNameMap: NewFunctionNameMap(),
    }
}

@@ -301,7 +304,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
    thinkingSb := strings.Builder{}
    toolContentSb := strings.Builder{}

-    events := h.harmonyParser.AddContent(content)
+    events := h.HarmonyParser.AddContent(content)
    for _, event := range events {
        switch event := event.(type) {
        case HarmonyEventHeaderComplete:

@@ -391,6 +394,38 @@ type FunctionNameMap struct {
    harmonyToUser map[string]string
}

+func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
+    // necessary to avoid exposing map internals
+    type alias struct {
+        UserToHarmony map[string]string `json:"userToHarmony"`
+        HarmonyToUser map[string]string `json:"harmonyToUser"`
+    }
+    return json.Marshal(alias{
+        UserToHarmony: m.userToHarmony,
+        HarmonyToUser: m.harmonyToUser,
+    })
+}
+
+func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
+    type alias struct {
+        UserToHarmony map[string]string `json:"userToHarmony"`
+        HarmonyToUser map[string]string `json:"harmonyToUser"`
+    }
+    var a alias
+    if err := json.Unmarshal(b, &a); err != nil {
+        return err
+    }
+    if m.userToHarmony == nil {
+        m.userToHarmony = make(map[string]string)
+    }
+    if m.harmonyToUser == nil {
+        m.harmonyToUser = make(map[string]string)
+    }
+    maps.Copy(m.userToHarmony, a.UserToHarmony)
+    maps.Copy(m.harmonyToUser, a.HarmonyToUser)
+    return nil
+}
+
func NewFunctionNameMap() *FunctionNameMap {
    return &FunctionNameMap{
        userToHarmony: make(map[string]string),
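The new `MarshalJSON`/`UnmarshalJSON` methods above serialize the unexported maps through an alias struct. A small round-trip sketch using only identifiers introduced in this diff:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/harmony"
)

func main() {
	m := harmony.NewFunctionNameMap()

	// MarshalJSON writes the unexported maps through the alias struct.
	b, err := json.Marshal(m)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b)) // {"userToHarmony":{},"harmonyToUser":{}}

	// UnmarshalJSON lazily allocates the maps on a zero-value receiver
	// and copies the decoded entries into them.
	var restored harmony.FunctionNameMap
	if err := json.Unmarshal(b, &restored); err != nil {
		panic(err)
	}
}
```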
@@ -1,4 +1,4 @@
-package server
+package harmony

import (
    "fmt"

@@ -2,10 +2,13 @@

This directory contains integration tests to exercise Ollama end-to-end to verify behavior

-By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
+By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
+

The integration tests have 2 modes of operating.

1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
-2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
+2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable

+> [!IMPORTANT]
+> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
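As the updated README notes, some integration tests hide behind additional build tags. A hypothetical test file gated on both tags, illustrating the `go:build` pattern that the `find integration -type f | xargs grep "go:build"` command would surface:

```go
//go:build integration && models

package integration

import "testing"

// TestBroadModels is a made-up example: it is only compiled and run when both
// tags are supplied, e.g. go test -tags=integration,models -timeout 60m ./integration/...
func TestBroadModels(t *testing.T) {
	t.Log("runs only with -tags=integration,models")
}
```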
@@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) {
    DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
}

-// Send multiple requests with prior context and ensure the response is coherant and expected
+// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
    modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
    req, resp := GenerateRequests()

@@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
        }(i)
    }
    wg.Wait()
+}
+
+// Send multiple chat requests with prior context and ensure the response is coherant and expected
+func TestChatWithHistory(t *testing.T) {
+    modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
+    req, resp := ChatRequests()
+    numParallel := 2
+    iterLimit := 2
+
+    softTimeout, hardTimeout := getTimeouts(t)
+    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    // Get the server running (if applicable) warm the model up with a single initial empty request
+    slog.Info("loading", "model", modelOverride)
+    err := client.Generate(ctx,
+        &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+        func(response api.GenerateResponse) error { return nil },
+    )
+    if err != nil {
+        t.Fatalf("failed to load model %s: %s", modelOverride, err)
+    }
+
+    var wg sync.WaitGroup
+    wg.Add(numParallel)
+    for i := range numParallel {
+        go func(i int) {
+            defer wg.Done()
+            k := i % len(req)
+            req[k].Model = modelOverride
+            for j := 0; j < iterLimit; j++ {
+                if time.Now().Sub(started) > softTimeout {
+                    slog.Info("exceeded soft timeout, winding down test")
+                    return
+                }
+                slog.Info("Starting", "thread", i, "iter", j)
+                // On slower GPUs it can take a while to process the concurrent requests
+                // so we allow a much longer initial timeout
+                assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+                if assistant == nil {
+                    t.Fatalf("didn't get an assistant response for context")
+                }
+                req[k].Messages = append(req[k].Messages,
+                    *assistant,
+                    api.Message{Role: "user", Content: "tell me more!"},
+                )
+            }
+        }(i)
+    }
+    wg.Wait()
}

@@ -19,6 +19,8 @@ import (
)

func TestMaxQueue(t *testing.T) {
+    t.Skip("this test needs to be re-evaluated to use a proper embedding model")
+
    if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
        t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
        return

@@ -567,6 +567,76 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
    }
}

+func ChatRequests() ([]api.ChatRequest, [][]string) {
+    genReqs, results := GenerateRequests()
+    reqs := make([]api.ChatRequest, len(genReqs))
+    for i := range reqs {
+        reqs[i].Model = genReqs[i].Model
+        reqs[i].Stream = genReqs[i].Stream
+        reqs[i].KeepAlive = genReqs[i].KeepAlive
+        reqs[i].Messages = []api.Message{
+            {
+                Role: "user",
+                Content: genReqs[i].Prompt,
+            },
+        }
+    }
+    return reqs, results
+}
+
+func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
+    stallTimer := time.NewTimer(initialTimeout)
+    var buf bytes.Buffer
+    role := "assistant"
+    fn := func(response api.ChatResponse) error {
+        // fmt.Print(".")
+        role = response.Message.Role
+        buf.Write([]byte(response.Message.Content))
+        if !stallTimer.Reset(streamTimeout) {
+            return errors.New("stall was detected while streaming response, aborting")
+        }
+        return nil
+    }
+
+    stream := true
+    req.Stream = &stream
+    done := make(chan int)
+    var genErr error
+    go func() {
+        genErr = client.Chat(ctx, &req, fn)
+        done <- 0
+    }()
+
+    select {
+    case <-stallTimer.C:
+        if buf.Len() == 0 {
+            t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
+        } else {
+            t.Errorf("generate stalled. Response so far:%s", buf.String())
+        }
+    case <-done:
+        if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
+            slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
+            return nil
+        }
+        require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages)
+        // Verify the response contains the expected data
+        response := buf.String()
+        atLeastOne := false
+        for _, resp := range anyResp {
+            if strings.Contains(strings.ToLower(response), resp) {
+                atLeastOne = true
+                break
+            }
+        }
+        require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
+        slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
+    case <-ctx.Done():
+        t.Error("outer test context done while waiting for generate")
+    }
+    return &api.Message{Role: role, Content: buf.String()}
+}
+
func skipUnderMinVRAM(t *testing.T, gb uint64) {
    // TODO use info API in the future
    if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
@@ -31,6 +31,7 @@ import (
    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/format"
    "github.com/ollama/ollama/fs/ggml"
+    "github.com/ollama/ollama/harmony"
    "github.com/ollama/ollama/llama"
    "github.com/ollama/ollama/logutil"
    "github.com/ollama/ollama/ml"

@@ -1331,7 +1332,9 @@ type CompletionRequest struct {
    Images []ImageData
    Options *api.Options

    Grammar string // set before sending the request to the subprocess
+    FunctionNameMap *harmony.FunctionNameMap
+    PrefillContent *bool
}

// DoneReason represents the reason why a completion response is done

@@ -1358,13 +1361,15 @@ func (d DoneReason) String() string {
}

type CompletionResponse struct {
    Content string `json:"content"`
-    DoneReason DoneReason `json:"done_reason"`
-    Done bool `json:"done"`
-    PromptEvalCount int `json:"prompt_eval_count"`
-    PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
-    EvalCount int `json:"eval_count"`
-    EvalDuration time.Duration `json:"eval_duration"`
+    Thinking string `json:"thinking"`
+    ToolCalls []api.ToolCall `json:"tool_calls"`
+    DoneReason DoneReason `json:"done_reason"`
+    Done bool `json:"done"`
+    PromptEvalCount int `json:"prompt_eval_count"`
+    PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+    EvalCount int `json:"eval_count"`
+    EvalDuration time.Duration `json:"eval_duration"`
}

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
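The `CompletionResponse` struct now carries thinking and tool calls alongside content. A minimal sketch of the resulting wire shape — the field values are invented purely for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
)

func main() {
	chunk := llm.CompletionResponse{
		Content:   "Hello",
		Thinking:  "drafting a short greeting",
		ToolCalls: []api.ToolCall{},
	}
	b, err := json.Marshal(chunk)
	if err != nil {
		panic(err)
	}
	// Emits content, thinking and tool_calls along with the usual done/eval fields.
	fmt.Println(string(b))
}
```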
@@ -1482,7 +1487,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
                return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
            }
            switch {
-            case strings.TrimSpace(c.Content) == lastToken:
+            case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
                tokenRepeat++
            default:
                lastToken = strings.TrimSpace(c.Content)

@@ -1495,16 +1500,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
                return ctx.Err()
            }

-            if c.Content != "" {
-                fn(CompletionResponse{
-                    Content: c.Content,
-                })
-            }
-
            if c.Done {
                fn(c)
                return nil
            }
+
+            if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
+                fn(c)
+            }
        }
    }

@@ -400,6 +400,8 @@ type Tensor interface {
    Bytes() []byte
    Floats() []float32

+    BackendSetFromIntSlice(s []int32)
+
    Neg(ctx Context) Tensor
    Add(ctx Context, t2 Tensor) Tensor
    Sub(ctx Context, t2 Tensor) Tensor

@@ -82,6 +82,7 @@ type Backend struct {
    // to the name that is used by the model definition
    tensorLoadTargets map[string][]string

+    schedMu sync.Mutex // Only one Compute can run at a time
    sched C.ggml_backend_sched_t
    schedBackends []C.ggml_backend_t
    schedBufts []C.ggml_backend_buffer_type_t

@@ -769,6 +770,8 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}

func (c *Context) Compute(tensors ...ml.Tensor) {
+    c.b.schedMu.Lock()
+    defer c.b.schedMu.Unlock()
    if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
        panic(fmt.Errorf("error computing ggml graph: %v", status))
    }

@@ -1037,6 +1040,12 @@ func (t *Tensor) Floats() (data []float32) {
    return
}

+func (t *Tensor) BackendSetFromIntSlice(s []int32) {
+    if len(s) > 0 {
+        C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
+    }
+}
+
func (t *Tensor) DType() ml.DType {
    switch t.t._type {
    case C.GGML_TYPE_F32:
@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
    // This function is also responsible for updating MultimodalHash for any Multimodal
    // that is modified to ensure that there is a unique hash value that accurately
    // represents the contents.
-    PostTokenize([]input.Input) ([]input.Input, error)
+    PostTokenize([]*input.Input) ([]*input.Input, error)
}

// Base implements the common fields and methods for all models

@@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool {
        t.Kind() == reflect.Slice
}

-func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
    if len(batch.Positions) != len(batch.Sequences) {
-        return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
+        return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
    }

    if len(batch.Positions) < 1 {
-        return nil, errors.New("batch size cannot be less than 1")
+        return nil, nil, errors.New("batch size cannot be less than 1")
    }

    batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))

@@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
    if cache != nil {
        err := cache.StartForward(ctx, batch, false)
        if err != nil {
-            return nil, err
+            return nil, nil, err
        }
    }

    t, err := m.Forward(ctx, batch)
    if err != nil {
-        return nil, err
+        return nil, nil, err
    }

-    ctx.Forward(t).Compute(t)
+    ctx.Forward(t)

-    return t, nil
+    return batch.Inputs, t, nil
}
@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
    return []input.Multimodal{{Tensor: visionOutputs}}, nil
}

-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-    var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+    var result []*input.Input

    for _, inp := range inputs {
        if len(inp.Multimodal) == 0 {

@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
            inputMultimodal := inp.Multimodal[0].Tensor

            result = append(result,
-                input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
-                input.Input{Token: 255999}, // "<start_of_image>""
-                input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+                &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
+                &input.Input{Token: 255999}, // "<start_of_image>""
+                &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
            )

            // add image token placeholders
-            result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
+            result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)

            result = append(result,
-                input.Input{Token: 256000}, // <end_of_image>
-                input.Input{Token: 108}, // "\n\n"
+                &input.Input{Token: 256000}, // <end_of_image>
+                &input.Input{Token: 108}, // "\n\n"
            )
        }
    }
@@ -134,16 +134,16 @@ type separator struct {
    y bool
}

-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-    var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+    var result []*input.Input
    for _, inp := range inputs {
        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
            continue
        }

-        var imageInputs []input.Input
-        imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
+        var imageInputs []*input.Input
+        imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>

        for i, mm := range inp.Multimodal {
            patchesPerChunk := mm.Tensor.Dim(1)

@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
            if i < len(inp.Multimodal)-1 {
                separator := mm.Data.(*separator)

-                imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-                imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+                imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+                imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)

                if separator.x {
-                    imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
+                    imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
                }
                if separator.y {
-                    imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+                    imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
                }
            } else {
-                imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
-                imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-                imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-                imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
+                imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
+                imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+                imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+                imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
            }
        }

@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together.
-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-    var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+    var result []*input.Input
    for _, inp := range inputs {
        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
        } else {
            for i, row := range inp.Multimodal {
                // [IMG]
-                result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
-                result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
+                result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
+                result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
                if i == len(inp.Multimodal)-1 {
                    // [IMG_END]
-                    result = append(result, input.Input{Token: 13})
+                    result = append(result, &input.Input{Token: 13})
                } else {
                    // [IMG_BREAK]
-                    result = append(result, input.Input{Token: 12})
+                    result = append(result, &input.Input{Token: 12})
                }
            }
        }
@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
    return []input.Multimodal{{Tensor: projectedOutputs}}, nil
}

-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
    for i := range inputs {
        if inputs[i].Multimodal != nil {
            inputs[i].Token = 128256 // <|image|>

@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
}

// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-    var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+    var result []*input.Input

    var (
        imageToken int32 = 151655

@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
            return nil, fmt.Errorf("failed to encode image prompt: %w", err)
        }
        for i := range pre {
-            result = append(result, input.Input{Token: pre[i]})
+            result = append(result, &input.Input{Token: pre[i]})
        }

        patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)

        // First add the vision start token
-        result = append(result, input.Input{Token: visionStartToken})
+        result = append(result, &input.Input{Token: visionStartToken})

        // Add the image token with the multimodal tensor data at the first position
-        result = append(result, input.Input{
+        result = append(result, &input.Input{
            Token: imageToken,
            Multimodal: inp.Multimodal,
            MultimodalHash: inp.MultimodalHash,

@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
        })

        // Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
-        result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
+        result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)

-        result = append(result, input.Input{Token: visionEndToken})
+        result = append(result, &input.Input{Token: visionEndToken})
        }
    }

@@ -86,7 +86,7 @@ type InputCacheSlot struct {
    Id int

    // Inputs that are stored in the KV cache
-    Inputs []input.Input
+    Inputs []*input.Input

    // is this cache actively being processed as part of a sequence?
    InUse bool

@@ -95,7 +95,7 @@ type InputCacheSlot struct {
    lastUsed time.Time
}

-func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
    var slot *InputCacheSlot
    var numPast int32
    var err error

@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
    return slot, prompt, nil
}

-func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
    longest := int32(-1)
    var longestSlot *InputCacheSlot

@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
    return longestSlot, longest, nil
}

-func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
    oldest := time.Now()
    var oldestSlot *InputCacheSlot

@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
    if longest > 0 && longestSlot != oldestSlot {
        slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
            len(longestSlot.Inputs))
-        oldestSlot.Inputs = make([]input.Input, longest)
+        oldestSlot.Inputs = make([]*input.Input, longest)
        copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
        if c.cache != nil {
            c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)

@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
    return oldestSlot, longest, nil
}

-func countCommonPrefix(a []input.Input, b []input.Input) int32 {
+func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
    var count int32

    for i := range a {

@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
}

type ErrReprocessInputs struct {
-    Inputs []input.Input
+    Inputs []*input.Input
}

func (e *ErrReprocessInputs) Error() string {

@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
        "id", slot.Id, "error", err)

        // Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
-        newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
+        newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
        copy(newInputs[:numKeep], slot.Inputs[:numKeep])
        copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])

        // Reset the cache
        _ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
-        slot.Inputs = []input.Input{}
+        slot.Inputs = []*input.Input{}

        // Return error with inputs that need to be reprocessed
        return &ErrReprocessInputs{Inputs: newInputs}
@@ -13,50 +13,50 @@ import (
func TestCountCommon(t *testing.T) {
    tests := []struct {
        name string
-        t1 []input.Input
-        t2 []input.Input
+        t1 []*input.Input
+        t2 []*input.Input
        expected int32
    }{
        {
            name: "Equal",
-            t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-            t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            expected: 3,
        },
        {
            name: "Prefix",
-            t1: []input.Input{{Token: 1}},
-            t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t1: []*input.Input{{Token: 1}},
+            t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            expected: 1,
        },
        {
            name: "Image Prefix",
-            t1: []input.Input{{MultimodalHash: 1}},
-            t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
+            t1: []*input.Input{{MultimodalHash: 1}},
+            t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
            expected: 1,
        },
        {
            name: "Mixed",
-            t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
-            t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
+            t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
+            t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
            expected: 2,
        },
        {
            name: "Mixed, Same Length",
-            t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
-            t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
+            t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
+            t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
            expected: 1,
        },
        {
            name: "Empty",
-            t1: []input.Input{},
-            t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t1: []*input.Input{},
+            t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            expected: 0,
        },
        {
            name: "Both Empty",
-            t1: []input.Input{},
-            t2: []input.Input{},
+            t1: []*input.Input{},
+            t2: []*input.Input{},
            expected: 0,
        },
    }
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
    tests := []struct {
        name string
        cache InputCache
-        prompt []input.Input
+        prompt []*input.Input
        longest expected
        best expected
    }{

@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{},
+                    Inputs: []*input.Input{},
                    InUse: false,
                    lastUsed: time.Time{},
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{},
+                    Inputs: []*input.Input{},
                    InUse: false,
                    lastUsed: time.Time{},
                },
            }},
-            prompt: []input.Input{{Token: 1}},
+            prompt: []*input.Input{{Token: 1}},
            longest: expected{result: 0, len: 0},
            best: expected{result: 0, len: 0},
        },

@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}},
+                    Inputs: []*input.Input{{Token: 1}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            }},
-            prompt: []input.Input{{Token: 1}, {Token: 2}},
+            prompt: []*input.Input{{Token: 1}, {Token: 2}},
            longest: expected{result: 1, len: 2},
            best: expected{result: 1, len: 2},
        },

@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{},
+                    Inputs: []*input.Input{},
                    InUse: false,
                    lastUsed: time.Time{},
                },
            }},
-            prompt: []input.Input{{Token: 2}},
+            prompt: []*input.Input{{Token: 2}},
            longest: expected{result: 0, len: 0},
            best: expected{result: 1, len: 0},
        },

@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
            slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{},
+                    Inputs: []*input.Input{},
                    InUse: false,
                    lastUsed: time.Time{},
                },
            },
            },
-            prompt: []input.Input{{Token: 1}},
+            prompt: []*input.Input{{Token: 1}},
            longest: expected{result: 0, len: 1},
            best: expected{result: 1, len: 1},
        },

@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}},
+                    Inputs: []*input.Input{{Token: 1}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            }},
-            prompt: []input.Input{{Token: 2}, {Token: 3}},
+            prompt: []*input.Input{{Token: 2}, {Token: 3}},
            longest: expected{result: 0, len: 0},
            best: expected{result: 1, len: 0},
        },

@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: true,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{{Token: 1}},
+                    Inputs: []*input.Input{{Token: 1}},
                    InUse: false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            }},
-            prompt: []input.Input{{Token: 1}, {Token: 2}},
+            prompt: []*input.Input{{Token: 1}, {Token: 2}},
            longest: expected{result: 1, len: 1},
            best: expected{result: 1, len: 2},
        },
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
    tests := []struct {
        name string
        cache InputCache
-        prompt []input.Input
+        prompt []*input.Input
        wantErr bool
        expectedSlotId int
        expectedPrompt int // expected length of remaining prompt

@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
            slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{},
+                    Inputs: []*input.Input{},
                    InUse: false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            },
            },
-            prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            wantErr: false,
            expectedSlotId: 0,
            expectedPrompt: 1, // Only token 3 remains

@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
            slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id: 1,
-                    Inputs: []input.Input{},
+                    Inputs: []*input.Input{},
                    InUse: false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            },
            },
-            prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            wantErr: false,
            expectedSlotId: 0,
            expectedPrompt: 1, // Only token 3 remains

@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
            slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: false,
                    lastUsed: time.Now().Add(-time.Second),
                },
            },
            },
-            prompt: []input.Input{{Token: 1}, {Token: 2}},
+            prompt: []*input.Input{{Token: 1}, {Token: 2}},
            wantErr: false,
            expectedSlotId: 0,
            expectedPrompt: 1, // Should leave 1 token for sampling

@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
            slots: []InputCacheSlot{
                {
                    Id: 0,
-                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []*input.Input{{Token: 1}, {Token: 2}},
                    InUse: true,
                    lastUsed: time.Now().Add(-time.Second),
                },
            },
            },
-            prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
            wantErr: true,
            expectedSlotId: -1,
            expectedPrompt: -1,
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
    tests := []struct {
        name string
        numCtx int32
-        inputs []input.Input
+        inputs []*input.Input
        numKeep int32
        cacheErr bool
        wantErr any

@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
        {
            name: "Normal shift",
            numCtx: 10,
-            inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+            inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
            numKeep: 2,
            cacheErr: false, // No error
            wantErr: nil,

@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
        {
            name: "Cache removal fails",
            numCtx: 10,
-            inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+            inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
            numKeep: 2,
            cacheErr: true,
            wantErr: &ErrReprocessInputs{},

@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
            }
            slot := &InputCacheSlot{
                Id: 123,
-                Inputs: make([]input.Input, len(tt.inputs)),
+                Inputs: make([]*input.Input, len(tt.inputs)),
            }
            copy(slot.Inputs, tt.inputs)

@@ -17,6 +17,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -28,6 +29,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/harmony"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
@@ -51,10 +53,10 @@ type Sequence struct {
 iBatch int

 // prompt inputs left to evaluate
-inputs []input.Input
+inputs []*input.Input

 // inputs that have been added to a batch but not yet submitted to Forward
-pendingInputs []input.Input
+pendingInputs []*input.Input

 // tokens that have been generated but not returned yet (e.g. for stop sequences)
 pendingResponses []string
@@ -86,6 +88,12 @@ type Sequence struct {
 // true if an embedding are to be returned instead of text generation
 embeddingOnly bool

+// true if the sequence if finished and marked for removal on next pass
+finished bool
+
+// True if we have to skip this sequence to shift the cache
+skipForShift bool
+
 doneReason llm.DoneReason

 // Metrics
@@ -182,8 +190,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
-var inputs []input.Input
+var inputs []*input.Input
 var ctxs []ml.Context
 var mmStore multimodalStore

@@ -210,7 +218,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 }

 for _, t := range tokens {
-inputs = append(inputs, input.Input{Token: t})
+inputs = append(inputs, &input.Input{Token: t})
 }

 // image - decode and store
@@ -243,7 +251,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [

 mmStore.addMultimodal(imageEmbeddings)

-inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 postTokenize = true
 }
 }
@@ -259,6 +267,27 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 return inputs, ctxs, mmStore, nil
 }

+type batchState struct {
+id int
+ctx ml.Context
+modelInput ml.Tensor
+modelOutput ml.Tensor
+batchInputs []*input.Input
+batch input.Batch
+seqs []*Sequence // full set of seqs at the time this batch was initiated
+initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1)
+
+// Signaled when this batches inputs are ready and compute can proceed
+inputsReadyCh chan struct{}
+
+// Signaling when Compute is about to begin on this batch, and
+// seqs have been updated to prepare for the next batch
+computeStartedCh chan struct{}
+
+// Signaled when this batches outputs are complete and the next batch can proceed
+outputsReadyCh chan struct{}
+}
+
 type Server struct {
 // modelPath is the location of the model to be loaded
 modelPath string
@@ -290,6 +319,16 @@ type Server struct {
 // TODO (jmorganca): make this n_batch
 batchSize int

+// Used to signal a hard failure during async processing which will panic the runner
+hardErrCh chan error
+
+// A prior batch that's still being processed
+// only read or written by forwardBatch
+pendingBatch *batchState
+
+// Simple counter used only for trace logging batches
+batchID int
+
 // protects access to everything below this line
 // this is context state needed for decoding
 mu sync.Mutex
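hardErrCh gives the asynchronous batch goroutine a way to surface unrecoverable failures back to the run loop, which turns them into a controlled panic instead of letting an arbitrary goroutine crash on its own. A minimal, illustrative sketch of that buffered error-channel pattern (not the runner's code):

package main

import (
    "errors"
    "fmt"
)

func main() {
    hardErrCh := make(chan error, 1) // buffered so the reporting goroutine never blocks

    // A background worker hits an unrecoverable failure and reports it
    // instead of panicking inside its own goroutine.
    go func() {
        hardErrCh <- errors.New("unrecoverable failure in async batch")
    }()

    // The owning loop receives the error where it can decide how to fail
    // (the runner panics at this point).
    err := <-hardErrCh
    fmt.Println("would panic with:", err)
}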
@@ -350,45 +389,132 @@ func flushPending(seq *Sequence) bool {
 }
 }

-func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
+func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
 seq := s.seqs[seqIndex]

+// finish could be called multiple times since we prepare 1 batch ahead
+// and multiple scenarios can lead to finishing a sequence
+// ensure only the first finish called is processed
+if seq.finished {
+return
+}
+
 flushPending(seq)
 seq.doneReason = reason
+seq.finished = true
 close(seq.responses)
 close(seq.embedding)
 seq.cache.InUse = false
+}
+
+func (s *Server) removeFinishedSequence(seqIndex int) {
 s.seqs[seqIndex] = nil
 s.seqsSem.Release(1)
 }

+// track batch state between forwardBatch, computeBatch and predictForwardBatch
+
 func (s *Server) run(ctx context.Context) {
 s.ready.Wait()

+var bs *batchState
 for {
 select {
 case <-ctx.Done():
 return
+case err := <-s.hardErrCh:
+panic(err)
 default:
-err := s.processBatch()
+var err error
+bs, err = s.forwardBatch()
 if err != nil {
 panic(err)
 }
+if bs == nil {
+continue
+}
+go s.computeBatch(bs)
 }
 }
 }

-func (s *Server) processBatch() error {
+// forwardBatch will calculate a batch.
+func (s *Server) forwardBatch() (*batchState, error) {
+inputsReady := false
+var inputsReadyCh chan struct{}
+
+// If we have a pending batch still processing, wait until Compute has started
+// before setting up the next batch so the seqs inputs are ready to receive their
+// token values and we get the correct input pointers for the batchInputs
+if s.pendingBatch != nil {
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
+<-s.pendingBatch.computeStartedCh
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
+inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
+} else {
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
+inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
+inputsReadyCh = make(chan struct{}, 1)
+}
+
 s.mu.Lock()
 for s.allNil() {
 s.cond.Wait() // Wait until an item is added
 }
 defer s.mu.Unlock()

-ctx := s.model.Backend().NewContext()
-defer ctx.Close()
+// If new sequences have been added with an active batch we delay preparing the next batch
+// until Compute has finished
+if s.pendingBatch != nil {
+for seqIdx := range s.seqs {
+if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
+s.mu.Unlock() // release the lock so computeBatch can finish up
+<-s.pendingBatch.outputsReadyCh
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
+s.mu.Lock()
+inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
+break
+}
+}
+}
+// Clear pending Batch - we'll set it if we have a batch with any inputs
+s.pendingBatch = nil

-var batchInputs []int32
+// Remove any finished sequences before recording the active set of seqs in the batch
+for seqIdx := range s.seqs {
+seq := s.seqs[seqIdx]
+if seq == nil {
+continue
+}
+if seq.finished {
+s.removeFinishedSequence(seqIdx)
+continue
+}
+if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
+s.finishSequence(seqIdx, llm.DoneReasonLength)
+s.removeFinishedSequence(seqIdx)
+continue
+}
+}
+
+// next batch
+nb := &batchState{
+id: s.batchID,
+initSeqIdx: s.nextSeq - 1,
+seqs: make([]*Sequence, len(s.seqs)),
+inputsReadyCh: inputsReadyCh,
+computeStartedCh: make(chan struct{}, 1),
+outputsReadyCh: make(chan struct{}, 1),
+}
+ctx := s.model.Backend().NewContext()
+nb.ctx = ctx
+
+// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
+copy(nb.seqs, s.seqs)
+
+// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
+var batchInputs []*input.Input
 var batch input.Batch

 resumeSeq := -1
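Splitting removeSequence into finishSequence plus removeFinishedSequence matters because a batch is now prepared one step ahead, so a sequence can be finished from more than one code path; the finished flag makes the first call win and keeps the channel closes from running twice. A small sketch of that idempotent-finish guard, with assumed names:

package main

import "fmt"

// seqState sketches the guard: finish may be called more than once,
// so a flag ensures only the first call takes effect.
type seqState struct {
    finished bool
    done     chan struct{}
}

func (s *seqState) finish() {
    if s.finished {
        return // already finished by an earlier pass
    }
    s.finished = true
    close(s.done) // closing twice would panic, hence the guard
}

func main() {
    s := &seqState{done: make(chan struct{})}
    s.finish()
    s.finish() // safe: the guard turns the second call into a no-op
    fmt.Println("finished:", s.finished)
}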
@@ -396,20 +522,13 @@ func (s *Server) processBatch() error {
 for range s.seqs {
 seqIdx = (seqIdx + 1) % len(s.seqs)
 seq := s.seqs[seqIdx]

 if seq == nil {
 continue
 }

-// if past the num predict limit
-if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-s.removeSequence(seqIdx, llm.DoneReasonLength)
-continue
-}
-
 if !s.cache.enabled {
 seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-seq.cache.Inputs = []input.Input{}
+seq.cache.Inputs = []*input.Input{}
 }

 batchSize := s.batchSize
@@ -449,18 +568,21 @@ func (s *Server) processBatch() error {
 // Prepend these inputs to the sequence's inputs queue for reprocessing
 seq.inputs = append(reprocess.Inputs, seq.inputs...)
 // Skip this sequence but continue processing the rest
+seq.skipForShift = true // cleared in computeBatch below for the next batch
 continue
 } else {
-return err
+ctx.Close()
+return nil, err
 }
 }
 }

-batchInputs = append(batchInputs, inp.Token)
+batchInputs = append(batchInputs, seq.inputs[i])
 if inp.Multimodal != nil {
 mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
 if err != nil {
-return err
+ctx.Close()
+return nil, err
 }
 batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
 }
@@ -468,10 +590,13 @@ func (s *Server) processBatch() error {
 batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 batch.Sequences = append(batch.Sequences, seq.cache.Id)

+// TODO BUG HERE!!!
+// Somehow sometimes iBatch isn't set correctly
 seq.iBatch = len(batch.Outputs)
 if i+1 == len(seq.inputs) {
 batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 }
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
 seq.pendingInputs = append(seq.pendingInputs, inp)
 }

@@ -485,36 +610,138 @@ func (s *Server) processBatch() error {
 }

 if len(batchInputs) == 0 {
-return nil
+slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
+ctx.Close()
+return nil, nil
 }
+s.batchID++

-modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
+var err error
+// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
+nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
 if err != nil {
-return fmt.Errorf("failed to decode batch: %w", err)
+ctx.Close()
+return nil, fmt.Errorf("failed to build graph: %w", err)
+}
+nb.batchInputs = batchInputs
+nb.batch = batch
+
+// computeBatch will close the context in the batch upon completion
+s.pendingBatch = nb
+
+if inputsReady {
+nb.inputsReadyCh <- struct{}{}
 }

-logits := modelOutput.Floats()
+return nb, nil
+}
+
+// Async processing of the next batch
+func (s *Server) computeBatch(bs *batchState) {
+if bs == nil || bs.ctx == nil {
+// Nothing to compute
+return
+}
+defer bs.ctx.Close()
+
+// Wait until inputs are ready
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
+<-bs.inputsReadyCh
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)
+
+// Once we complete, signal the next batch of inputs are ready
+// This will unblock the next computeBatch, or forwardBatch if new seqs come in
+defer func() {
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
+bs.outputsReadyCh <- struct{}{}
+}()
+
+s.mu.Lock()
+
+// Gather the actual input token values now that they're ready
+batchInputs := make([]int32, len(bs.batchInputs))
+for i := range batchInputs {
+batchInputs[i] = bs.batchInputs[i].Token
+}
+
+// TODO the following logic could be run in a go routine to possibly speed up getting to Compute
+
+// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
+// so that forwardBatch can build a batchInputs set which will eventually contain the actual
+// decoded tokens.
+promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip
+nextBatchTokens := make([]*input.Input, len(s.seqs))
+iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
 for i, seq := range s.seqs {
+iBatches[i] = -1
 if seq == nil {
 continue
 }
+// Skip over any newly added sequences
+if bs.seqs[i] == nil {
+continue
+}
+
 // After calling Forward, pending inputs are now in the cache
 if len(seq.pendingInputs) > 0 {
 seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-seq.pendingInputs = []input.Input{}
+seq.pendingInputs = []*input.Input{}
 }

 // don't sample prompt processing
 if len(seq.inputs) != 0 {
 if !s.cache.enabled {
-return errors.New("caching disabled but unable to fit entire input in a batch")
+s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
+return
 }
+// Record so we can skip during Decode
+promptProcessing[i] = true
 continue
 }

 seq.numPredicted++
+nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
+seq.inputs = []*input.Input{nextToken}
+nextBatchTokens[i] = nextToken
+iBatches[i] = seq.iBatch
+}
+
+// At this point the seqs are ready for forwardBatch to move forward so unblock
+s.mu.Unlock()
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
+bs.computeStartedCh <- struct{}{}
+
+bs.modelInput.BackendSetFromIntSlice(batchInputs)
+bs.ctx.Compute(bs.modelOutput)
+logits := bs.modelOutput.Floats()
+
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)
+
+s.mu.Lock()
+defer s.mu.Unlock()
+
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
+for i, seq := range s.seqs {
+if seq == nil {
+continue
+}
+// Skip over any newly added sequences
+if bs.seqs[i] == nil {
+continue
+}
+
+// Detect if the sequence we're processing has already been completed and replaced
+// with a new sequence
+if seq != bs.seqs[i] {
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
+continue
+}
+
+// don't sample prompt processing
+if promptProcessing[i] {
+continue
+}
+
 if seq.numPredicted == 1 {
 seq.startGenerationTime = time.Now()
 }
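computeBatch snapshots the sequence slots when the batch is created (bs.seqs) and later compares pointers against the live slots, so results belonging to a sequence that has since finished and been replaced are discarded rather than written into the newcomer. A standalone sketch of that pointer-identity check (illustrative, not the runner's code):

package main

import "fmt"

type sequence struct{ prompt string }

func main() {
    // live is the scheduler's slot table; snapshot is taken when the batch is built.
    live := []*sequence{{prompt: "a"}, {prompt: "b"}}
    snapshot := make([]*sequence, len(live))
    copy(snapshot, live)

    // While the batch was computing, slot 1 finished and was reused for a new request.
    live[1] = &sequence{prompt: "c"}

    for i, seq := range live {
        if seq != snapshot[i] {
            // Pointer identity says this result belongs to an older sequence.
            fmt.Println("slot", i, "was replaced; discarding its output")
            continue
        }
        fmt.Println("slot", i, "still owns its result")
    }
}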
@@ -522,35 +749,46 @@ func (s *Server) processBatch() error {
 // if done processing the prompt, generate an embedding and return
 if seq.embeddingOnly {
 // TODO(jessegross): Embedding support
-slog.Warn("generation of embedding outputs not yet supported")
+slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
-s.removeSequence(i, llm.DoneReasonStop)
+s.finishSequence(i, llm.DoneReasonStop)
 continue
 }

 // sample a token
-vocabSize := len(logits) / len(batch.Outputs)
+vocabSize := len(logits) / len(bs.batch.Outputs)
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
-token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
+token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 if err != nil {
-return fmt.Errorf("failed to sample token: %w", err)
+s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
+return
 }

+nextBatchTokens[i].Token = token
+
 // if it's an end of sequence token, break
 if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
 // TODO (jmorganca): we should send this back
 // as it's important for the /api/generate context
 // seq.responses <- piece
+slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
-s.removeSequence(i, llm.DoneReasonStop)
+s.finishSequence(i, llm.DoneReasonStop)
 continue
 }

 piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
 if err != nil {
-return err
+s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
+return
 }

-seq.inputs = []input.Input{{Token: token}}
+if nextBatchTokens[i] == nil {
+slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
+s.hardErrCh <- fmt.Errorf("expected a single token during decode")
+return
+}
+
+// fill in the final selected token value to replace the placeholder in the next batch
+// nextBatchTokensWritten++
+
 seq.pendingResponses = append(seq.pendingResponses, piece)
 sequence := strings.Join(seq.pendingResponses, "")
@@ -575,9 +813,10 @@ func (s *Server) processBatch() error {
 if tokenTruncated || origLen == newLen {
 tokenLen--
 }

 seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

-s.removeSequence(i, llm.DoneReasonStop)
+s.finishSequence(i, llm.DoneReasonStop)
 continue
 }

@@ -590,11 +829,9 @@ func (s *Server) processBatch() error {
 }

 if !flushPending(seq) {
-s.removeSequence(i, llm.DoneReasonConnectionClosed)
+s.finishSequence(i, llm.DoneReasonConnectionClosed)
 }
 }
-
-return nil
 }

 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -604,6 +841,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 return
 }

+var harmonyMessageHandler *harmony.HarmonyMessageHandler
+var harmonyToolParser *harmony.HarmonyToolCallAccumulator
+if req.FunctionNameMap != nil {
+harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+harmonyMessageHandler.FunctionNameMap = req.FunctionNameMap
+harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
+harmonyToolParser = harmonyMessageHandler.CreateToolParser()
+}
+
 if req.Options == nil {
 opts := api.DefaultOptions()
 req.Options = &opts
@@ -694,8 +940,16 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 return
 case content, ok := <-seq.responses:
 if ok {
+var thinking string
+if harmonyMessageHandler != nil {
+var toolContent string
+content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
+harmonyToolParser.Add(toolContent)
+}
+
 if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 Content: content,
+Thinking: thinking,
 }); err != nil {
 http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 close(seq.quit)
@@ -704,7 +958,29 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {

 flusher.Flush()
 } else {
+var toolCalls []api.ToolCall
+if harmonyMessageHandler != nil {
+toolName, toolContent := harmonyToolParser.Drain()
+if toolName != nil {
+*toolName = strings.TrimPrefix(*toolName, "functions.")
+*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
+var args api.ToolCallFunctionArguments
+if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
+close(seq.quit)
+return
+}
+toolCalls = append(toolCalls, api.ToolCall{
+Function: api.ToolCallFunction{
+Name: *toolName,
+Arguments: args,
+},
+})
+}
+}
+
 if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+ToolCalls: toolCalls,
 Done: true,
 DoneReason: seq.doneReason,
 PromptEvalCount: seq.numPromptInputs,
@@ -736,7 +1012,10 @@ func (s *Server) reserveWorstCaseGraph() error {
 defer ctx.Close()

 var err error
-inputs := make([]input.Input, s.batchSize)
+inputs := make([]*input.Input, s.batchSize)
+for i := range inputs {
+inputs[i] = &input.Input{}
+}
 mmStore := newMultimodalStore()

 // Multimodal strategy:
@@ -778,8 +1057,11 @@ func (s *Server) reserveWorstCaseGraph() error {
 }

 if len(inputs) < s.batchSize {
-newInputs := make([]input.Input, s.batchSize)
+newInputs := make([]*input.Input, s.batchSize)
 copy(newInputs, inputs)
+for i := len(inputs); i < s.batchSize; i++ {
+newInputs[i] = &input.Input{}
+}
 inputs = newInputs
 }
 }
@@ -842,6 +1124,7 @@ func (s *Server) allocModel(
 // Convert memory allocation panics to errors
 defer func() {
 if r := recover(); r != nil {
+debug.PrintStack()
 if err, ok := r.(error); ok {
 panicErr = err
 } else {
@@ -1011,6 +1294,7 @@ func Execute(args []string) error {
 server := &Server{
 modelPath: *mpath,
 status: llm.ServerStatusLaunched,
+hardErrCh: make(chan error, 1),
 }

 server.cond = sync.NewCond(&server.mu)

@@ -32,6 +32,7 @@ import (
 "github.com/ollama/ollama/envconfig"
 "github.com/ollama/ollama/format"
 "github.com/ollama/ollama/fs/ggml"
+"github.com/ollama/ollama/harmony"
 "github.com/ollama/ollama/llm"
 "github.com/ollama/ollama/logutil"
 "github.com/ollama/ollama/openai"
@@ -194,14 +195,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 return
 }

-useHarmony := shouldUseHarmony(*m) && !req.Raw
+useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
-var harmonyMessageHandler *HarmonyMessageHandler
-var harmonyToolParser *HarmonyToolCallAccumulator
-if useHarmony {
-harmonyMessageHandler = NewHarmonyMessageHandler()
-harmonyMessageHandler.harmonyParser.AddImplicitStart()
-harmonyToolParser = harmonyMessageHandler.CreateToolParser()
-}

 // Validate Think value: string values currently only allowed for gptoss models
 if req.Think != nil && req.Think.IsString() && !useHarmony {
@@ -362,12 +356,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 },
 }

-if useHarmony {
+if !useHarmony && thinkingState != nil {
-content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
-res.Response = content
-res.Thinking = thinking
-harmonyToolParser.Add(toolContent)
-} else if thinkingState != nil {
 thinking, content := thinkingState.AddContent(cr.Content)
 res.Thinking = thinking
 res.Response = content
@@ -378,26 +367,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 }

 if cr.Done {
-if useHarmony {
-toolName, toolContent := harmonyToolParser.Drain()
-if toolName != nil {
-*toolName = strings.TrimPrefix(*toolName, "functions.")
-var args api.ToolCallFunctionArguments
-if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
-errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
-ch <- gin.H{"error": errStr}
-return
-}
-
-res.ToolCalls = append(res.ToolCalls, api.ToolCall{
-Function: api.ToolCallFunction{
-Name: *toolName,
-Arguments: args,
-},
-})
-}
-}
-
 res.DoneReason = cr.DoneReason.String()
 res.TotalDuration = time.Since(checkpointStart)
 res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -1603,27 +1572,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
 }
 msgs = filterThinkTags(msgs, m)

-var harmonyMessageHandler *HarmonyMessageHandler
-var harmonyToolParser *HarmonyToolCallAccumulator
-
-useHarmony := shouldUseHarmony(*m)
+useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)

 processedTools := req.Tools
+var functionNameMap *harmony.FunctionNameMap
+var prefillContentOrThinking *bool
 if useHarmony {
-harmonyMessageHandler = NewHarmonyMessageHandler()
+functionNameMap = harmony.NewFunctionNameMap()
 var lastMessage *api.Message
 if len(msgs) > 0 {
 lastMessage = &msgs[len(msgs)-1]
 }
-harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
-harmonyToolParser = harmonyMessageHandler.CreateToolParser()

+// prefill content or thinking flag if the last message is an assistant message
+if lastMessage != nil && lastMessage.Role == "assistant" {
+if lastMessage.Content != "" {
+trueVal := true
+// true sets content to be prefilled
+prefillContentOrThinking = &trueVal
+} else if lastMessage.Thinking != "" {
+// false sets thinking to be prefilled
+falseVal := false
+prefillContentOrThinking = &falseVal
+}
+}
 // make a copy of tools to pass to the chat prompt. Function names may be
 // renamed to be valid Harmony function names.
 processedTools = make([]api.Tool, len(req.Tools))
 copy(processedTools, req.Tools)
 for i, tool := range processedTools {
-processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name)
+processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
 }
 }

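The prefill flag built above is a tri-state *bool: nil means start a fresh assistant turn, true means the last assistant message already has content to continue, and false means it has thinking to continue. A small sketch of how such a flag could be derived, using a hypothetical message type rather than the server's:

package main

import "fmt"

// message mirrors only the fields the prefill decision looks at (hypothetical).
type message struct {
    Role, Content, Thinking string
}

// prefillFlag returns nil for no prefill, true to continue content,
// false to continue thinking, matching the tri-state used above.
func prefillFlag(last *message) *bool {
    if last == nil || last.Role != "assistant" {
        return nil
    }
    if last.Content != "" {
        v := true
        return &v
    }
    if last.Thinking != "" {
        v := false
        return &v
    }
    return nil
}

func main() {
    f := prefillFlag(&message{Role: "assistant", Thinking: "step 1..."})
    fmt.Println(*f) // false: continue the thinking channel
}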
@@ -1672,15 +1650,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
 defer close(ch)

 if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 Prompt: prompt,
 Images: images,
 Format: req.Format,
 Options: opts,
+FunctionNameMap: functionNameMap,
+PrefillContent: prefillContentOrThinking,
 }, func(r llm.CompletionResponse) {
 res := api.ChatResponse{
 Model: req.Model,
 CreatedAt: time.Now().UTC(),
-Message: api.Message{Role: "assistant", Content: r.Content},
+Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
 Done: r.Done,
 Metrics: api.Metrics{
 PromptEvalCount: r.PromptEvalCount,
@@ -1696,31 +1676,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 }

 if useHarmony {
-content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
-res.Message.Content = content
-res.Message.Thinking = thinking
-harmonyToolParser.Add(toolContent)
-
-if r.Done {
-toolName, toolContent := harmonyToolParser.Drain()
-if toolName != nil {
-*toolName = strings.TrimPrefix(*toolName, "functions.")
-*toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName)
-var args api.ToolCallFunctionArguments
-if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
-errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
-ch <- gin.H{"error": errStr}
-return
-}
-res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
-}
-}
-
 // only send messages with meaningful content (empty messages confuse clients)
 if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
 ch <- res
 }

 return
 }