Mirror of https://github.com/mudler/LocalAI.git (synced 2026-03-31 21:25:59 -04:00)
150 lines · 4.2 KiB · Go
package openai
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/mudler/LocalAI/core/backend"
|
|
"github.com/mudler/LocalAI/core/config"
|
|
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
model "github.com/mudler/LocalAI/pkg/model"
|
|
"github.com/mudler/xlog"
|
|
)
|
|
|
|
func ComputeChoices(
|
|
req *schema.OpenAIRequest,
|
|
predInput string,
|
|
config *config.ModelConfig,
|
|
bcl *config.ModelConfigLoader,
|
|
o *config.ApplicationConfig,
|
|
loader *model.ModelLoader,
|
|
cb func(string, *[]schema.Choice),
|
|
tokenCallback func(string, backend.TokenUsage) bool,
|
|
shouldRetry ...func(int) bool,
|
|
) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
|
|
n := req.N // number of completions to return
|
|
result := []schema.Choice{}
|
|
|
|
if n == 0 {
|
|
n = 1
|
|
}
|
|
|
|
// Extract the optional shouldRetry callback
|
|
var shouldRetryFn func(int) bool
|
|
if len(shouldRetry) > 0 {
|
|
shouldRetryFn = shouldRetry[0]
|
|
}
|
|
|
|
images := []string{}
|
|
for _, m := range req.Messages {
|
|
images = append(images, m.StringImages...)
|
|
}
|
|
videos := []string{}
|
|
for _, m := range req.Messages {
|
|
videos = append(videos, m.StringVideos...)
|
|
}
|
|
audios := []string{}
|
|
for _, m := range req.Messages {
|
|
audios = append(audios, m.StringAudios...)
|
|
}
|
|
|
|
// Serialize tools and tool_choice to JSON strings
|
|
toolsJSON := ""
|
|
if len(req.Tools) > 0 {
|
|
toolsBytes, err := json.Marshal(req.Tools)
|
|
if err == nil {
|
|
toolsJSON = string(toolsBytes)
|
|
}
|
|
}
|
|
toolChoiceJSON := ""
|
|
if req.ToolsChoice != nil {
|
|
toolChoiceBytes, err := json.Marshal(req.ToolsChoice)
|
|
if err == nil {
|
|
toolChoiceJSON = string(toolChoiceBytes)
|
|
}
|
|
}
|
|
|
|
// Extract logprobs from request
|
|
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
|
|
var logprobs *int
|
|
var topLogprobs *int
|
|
if req.Logprobs.IsEnabled() {
|
|
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
|
|
if req.TopLogprobs != nil {
|
|
topLogprobs = req.TopLogprobs
|
|
// For backend compatibility, set logprobs to the top_logprobs value
|
|
logprobs = req.TopLogprobs
|
|
} else {
|
|
// Default to 1 if logprobs is true but top_logprobs not specified
|
|
val := 1
|
|
logprobs = &val
|
|
topLogprobs = &val
|
|
}
|
|
}
|
|
|
|
// Extract logit_bias from request
|
|
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
|
|
var logitBias map[string]float64
|
|
if len(req.LogitBias) > 0 {
|
|
logitBias = req.LogitBias
|
|
}
|
|
|
|
// get the model function to call for the result
|
|
predFunc, err := backend.ModelInferenceFunc(
|
|
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata)
|
|
if err != nil {
|
|
return result, backend.TokenUsage{}, nil, err
|
|
}
|
|
|
|
tokenUsage := backend.TokenUsage{}
|
|
var allChatDeltas []*pb.ChatDelta
|
|
|
|
const maxRetries = 5
|
|
|
|
for i := 0; i < n; i++ {
|
|
var prediction backend.LLMResponse
|
|
|
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
|
p, err := predFunc()
|
|
if err != nil {
|
|
return result, backend.TokenUsage{}, nil, err
|
|
}
|
|
prediction = p
|
|
|
|
// Built-in: retry on truly empty response (no tokens at all)
|
|
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
|
|
xlog.Warn("Backend returned empty response, retrying",
|
|
"attempt", attempt+1, "maxRetries", maxRetries)
|
|
continue
|
|
}
|
|
|
|
tokenUsage.Prompt = prediction.Usage.Prompt
|
|
tokenUsage.Completion = prediction.Usage.Completion
|
|
tokenUsage.TimingPromptProcessing = prediction.Usage.TimingPromptProcessing
|
|
tokenUsage.TimingTokenGeneration = prediction.Usage.TimingTokenGeneration
|
|
|
|
allChatDeltas = prediction.ChatDeltas
|
|
|
|
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
|
cb(finetunedResponse, &result)
|
|
|
|
// Caller-driven retry (tool parsing, reasoning-only, etc.)
|
|
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
|
|
// Caller has already reset its state inside shouldRetry
|
|
result = result[:0]
|
|
allChatDeltas = nil
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
|
|
// Add logprobs to the last choice if present
|
|
if prediction.Logprobs != nil && len(result) > 0 {
|
|
result[len(result)-1].Logprobs = prediction.Logprobs
|
|
}
|
|
}
|
|
return result, tokenUsage, allChatDeltas, err
|
|
}
|