mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
feat: inferencing default, automatic tool parsing fallback and wire min_p (#9092)
* feat: wire min_p Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: inferencing defaults Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(refactor): re-use iterative parser Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: generate automatically inference defaults from unsloth Instead of trying to re-invent the wheel and maintain here the inference defaults, prefer to consume unsloth ones, and contribute there as necessary. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: apply defaults also to models installed via gallery Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: be consistent and apply fallback to all endpoint Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
8036d22ec6
commit
031a36c995
48
.github/workflows/bump-inference-defaults.yml
vendored
Normal file
48
.github/workflows/bump-inference-defaults.yml
vendored
Normal file
# Daily job: re-fetch upstream inference defaults from unsloth and open a
# PR when core/config/inference_defaults.json changed.
name: Bump inference defaults

on:
  schedule:
    # Run daily at 06:00 UTC
    - cron: '0 6 * * *'
  workflow_dispatch: # Allow manual trigger

permissions:
  contents: write
  pull-requests: write

jobs:
  bump:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-go@v5
        with:
          go-version-file: go.mod

      - name: Re-fetch inference defaults
        run: make generate-force

      # Only open a PR when the generated file actually differs.
      - name: Check for changes
        id: diff
        run: |
          if git diff --quiet core/config/inference_defaults.json; then
            echo "changed=false" >> "$GITHUB_OUTPUT"
          else
            echo "changed=true" >> "$GITHUB_OUTPUT"
          fi

      - name: Create Pull Request
        if: steps.diff.outputs.changed == 'true'
        uses: peter-evans/create-pull-request@v7
        with:
          commit-message: "chore: bump inference defaults from unsloth"
          title: "chore: bump inference defaults from unsloth"
          body: |
            Auto-generated update of `core/config/inference_defaults.json` from
            [unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json).

            This PR was created automatically by the `bump-inference-defaults` workflow.
          branch: chore/bump-inference-defaults
          delete-branch: true
          labels: automated
||||
12
Makefile
12
Makefile
@@ -107,7 +107,7 @@ core/http/react-ui/dist: react-ui
|
||||
|
||||
## Build:
|
||||
|
||||
build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project
|
||||
build: protogen-go generate install-go-tools core/http/react-ui/dist ## Build the project
|
||||
$(info ${GREEN}I local-ai build info:${RESET})
|
||||
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
||||
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
|
||||
@@ -398,6 +398,16 @@ protogen-go: protoc install-go-tools
|
||||
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||
backend/backend.proto
|
||||
|
||||
core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing)
|
||||
$(GOCMD) generate ./core/config/...
|
||||
|
||||
.PHONY: generate
|
||||
generate: core/config/inference_defaults.json ## Ensure inference defaults exist
|
||||
|
||||
.PHONY: generate-force
|
||||
generate-force: ## Re-fetch inference defaults from unsloth (always)
|
||||
$(GOCMD) generate ./core/config/...
|
||||
|
||||
.PHONY: protogen-go-clean
|
||||
protogen-go-clean:
|
||||
$(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
|
||||
|
||||
@@ -178,6 +178,7 @@ message PredictOptions {
|
||||
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
|
||||
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
|
||||
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
|
||||
float MinP = 53; // Minimum probability sampling threshold (0.0 = disabled)
|
||||
}
|
||||
|
||||
// ToolCallDelta represents an incremental tool call update from the C++ parser.
|
||||
|
||||
@@ -136,6 +136,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["mirostat_eta"] = predict->mirostateta();
|
||||
data["n_keep"] = predict->nkeep();
|
||||
data["seed"] = predict->seed();
|
||||
data["min_p"] = predict->minp();
|
||||
|
||||
|
||||
std::string grammar_str = predict->grammar();
|
||||
|
||||
@@ -252,6 +252,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
TopP: float32(*c.TopP),
|
||||
NDraft: c.NDraft,
|
||||
TopK: int32(*c.TopK),
|
||||
MinP: float32(*c.MinP),
|
||||
Tokens: int32(*c.Maxtokens),
|
||||
Threads: int32(*c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
|
||||
30
core/config/gen_inference_defaults/README.md
Normal file
30
core/config/gen_inference_defaults/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# gen_inference_defaults
|
||||
|
||||
This tool fetches per-model-family inference parameter defaults from [unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json), validates the data, remaps field names to LocalAI conventions, and writes `core/config/inference_defaults.json`.
|
||||
|
||||
## What it does
|
||||
|
||||
1. Fetches the latest `inference_defaults.json` from unsloth's repo
|
||||
2. Validates that every entry has required fields (`temperature`, `top_p`, `top_k`)
|
||||
3. Validates that every pattern references an existing family
|
||||
4. Warns if pattern ordering would cause shorter prefixes to shadow longer ones
|
||||
5. Remaps `repetition_penalty` → `repeat_penalty` (LocalAI naming)
|
||||
6. Filters to allowed fields only: `temperature`, `top_p`, `top_k`, `min_p`, `repeat_penalty`, `presence_penalty`
|
||||
7. Writes the validated JSON to `core/config/inference_defaults.json`
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Only regenerate if the file is missing (runs during make build)
|
||||
make generate
|
||||
|
||||
# Force re-fetch from unsloth
|
||||
make generate-force
|
||||
|
||||
# Or directly via go generate
|
||||
go generate ./core/config/...
|
||||
```
|
||||
|
||||
## Automation
|
||||
|
||||
The GitHub Actions workflow `.github/workflows/bump-inference-defaults.yml` runs `make generate-force` daily and opens a PR if the upstream data changed.
|
||||
222
core/config/gen_inference_defaults/main.go
Normal file
222
core/config/gen_inference_defaults/main.go
Normal file
@@ -0,0 +1,222 @@
|
||||
// gen_inference_defaults fetches unsloth's inference_defaults.json,
|
||||
// validates its structure, remaps field names to LocalAI conventions,
|
||||
// and writes the result to core/config/inference_defaults.json.
|
||||
//
|
||||
// Run via: go generate ./core/config/
|
||||
package main
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"
)
|
||||
|
||||
const (
	// unslothURL is the canonical upstream source of the defaults data.
	unslothURL = "https://raw.githubusercontent.com/unslothai/unsloth/main/studio/backend/assets/configs/inference_defaults.json"
	// outputFile is written relative to the package directory
	// (go generate runs with the package dir as CWD).
	outputFile = "inference_defaults.json"
)

// unslothDefaults mirrors the upstream JSON structure.
type unslothDefaults struct {
	Comment  string                        `json:"_comment"`
	Families map[string]map[string]float64 `json:"families"`
	Patterns []string                      `json:"patterns"`
}

// localAIDefaults is our output structure. It currently has the same
// shape as unslothDefaults but is kept as a distinct type so the two
// formats can diverge independently.
type localAIDefaults struct {
	Comment  string                        `json:"_comment"`
	Families map[string]map[string]float64 `json:"families"`
	Patterns []string                      `json:"patterns"`
}

// requiredFields are the fields every family entry must have.
var requiredFields = []string{"temperature", "top_p", "top_k"}

// fieldRemap maps unsloth field names to LocalAI field names.
var fieldRemap = map[string]string{
	"repetition_penalty": "repeat_penalty",
}

// allowedFields are the only fields we keep (after remapping);
// anything else from upstream is dropped.
var allowedFields = map[string]bool{
	"temperature":      true,
	"top_p":            true,
	"top_k":            true,
	"min_p":            true,
	"repeat_penalty":   true,
	"presence_penalty": true,
}
|
||||
|
||||
func main() {
|
||||
fmt.Fprintf(os.Stderr, "Fetching %s ...\n", unslothURL)
|
||||
|
||||
resp, err := http.Get(unslothURL)
|
||||
if err != nil {
|
||||
fatal("fetch failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fatal("fetch returned HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fatal("read body: %v", err)
|
||||
}
|
||||
|
||||
var upstream unslothDefaults
|
||||
if err := json.Unmarshal(body, &upstream); err != nil {
|
||||
fatal("parse upstream JSON: %v", err)
|
||||
}
|
||||
|
||||
// Validate structure
|
||||
if len(upstream.Families) == 0 {
|
||||
fatal("upstream has no families")
|
||||
}
|
||||
if len(upstream.Patterns) == 0 {
|
||||
fatal("upstream has no patterns")
|
||||
}
|
||||
|
||||
// Validate every pattern references a family
|
||||
for _, p := range upstream.Patterns {
|
||||
if _, ok := upstream.Families[p]; !ok {
|
||||
fatal("pattern %q has no corresponding family entry", p)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate every family has required fields and remap field names
|
||||
output := localAIDefaults{
|
||||
Comment: "Auto-generated from unsloth inference_defaults.json. DO NOT EDIT. Run go generate ./core/config/ to update.",
|
||||
Families: make(map[string]map[string]float64, len(upstream.Families)),
|
||||
Patterns: upstream.Patterns,
|
||||
}
|
||||
|
||||
// Sort family names for deterministic output
|
||||
familyNames := make([]string, 0, len(upstream.Families))
|
||||
for name := range upstream.Families {
|
||||
familyNames = append(familyNames, name)
|
||||
}
|
||||
sort.Strings(familyNames)
|
||||
|
||||
for _, name := range familyNames {
|
||||
params := upstream.Families[name]
|
||||
|
||||
// Check required fields
|
||||
for _, req := range requiredFields {
|
||||
found := false
|
||||
for k := range params {
|
||||
mapped := k
|
||||
if m, ok := fieldRemap[k]; ok {
|
||||
mapped = m
|
||||
}
|
||||
if mapped == req || k == req {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
fatal("family %q missing required field %q", name, req)
|
||||
}
|
||||
}
|
||||
|
||||
// Remap and filter fields
|
||||
remapped := make(map[string]float64)
|
||||
for k, v := range params {
|
||||
if newName, ok := fieldRemap[k]; ok {
|
||||
k = newName
|
||||
}
|
||||
if allowedFields[k] {
|
||||
remapped[k] = v
|
||||
}
|
||||
}
|
||||
output.Families[name] = remapped
|
||||
}
|
||||
|
||||
// Validate patterns are ordered longest-match-first within same prefix groups
|
||||
validatePatternOrder(output.Patterns)
|
||||
|
||||
// Marshal with ordered keys for readability
|
||||
data, err := marshalOrdered(output)
|
||||
if err != nil {
|
||||
fatal("marshal output: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outputFile, data, 0644); err != nil {
|
||||
fatal("write %s: %v", outputFile, err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Written %s (%d families, %d patterns)\n",
|
||||
outputFile, len(output.Families), len(output.Patterns))
|
||||
}
|
||||
|
||||
// validatePatternOrder warns (on stderr) whenever a pattern is a prefix
// of a pattern that appears later in the list: matching is first-hit, so
// the shorter pattern would shadow the longer, more specific one
// (e.g., "qwen3" listed before "qwen3.5").
func validatePatternOrder(patterns []string) {
	for earlier, short := range patterns {
		for later := earlier + 1; later < len(patterns); later++ {
			long := patterns[later]
			if !strings.HasPrefix(long, short) {
				continue
			}
			fmt.Fprintf(os.Stderr, "WARNING: pattern %q at index %d is a prefix of %q at index %d — longer match should come first\n",
				short, earlier, long, later)
		}
	}
}
|
||||
|
||||
// marshalOrdered produces JSON with families in pattern order for readability
|
||||
func marshalOrdered(d localAIDefaults) ([]byte, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{\n")
|
||||
sb.WriteString(fmt.Sprintf(" %q: %q,\n", "_comment", d.Comment))
|
||||
sb.WriteString(" \"families\": {\n")
|
||||
|
||||
// Write families in pattern order, then any remaining not in patterns
|
||||
written := make(map[string]bool)
|
||||
allFamilies := make([]string, 0, len(d.Families))
|
||||
for _, p := range d.Patterns {
|
||||
if _, ok := d.Families[p]; ok && !written[p] {
|
||||
allFamilies = append(allFamilies, p)
|
||||
written[p] = true
|
||||
}
|
||||
}
|
||||
for name := range d.Families {
|
||||
if !written[name] {
|
||||
allFamilies = append(allFamilies, name)
|
||||
}
|
||||
}
|
||||
|
||||
for i, name := range allFamilies {
|
||||
params := d.Families[name]
|
||||
paramJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
comma := ","
|
||||
if i == len(allFamilies)-1 {
|
||||
comma = ""
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %q: %s%s\n", name, paramJSON, comma))
|
||||
}
|
||||
|
||||
sb.WriteString(" },\n")
|
||||
|
||||
// Patterns array
|
||||
patternsJSON, err := json.Marshal(d.Patterns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" \"patterns\": %s\n", patternsJSON))
|
||||
sb.WriteString("}\n")
|
||||
|
||||
return []byte(sb.String()), nil
|
||||
}
|
||||
|
||||
// fatal prints a tool-prefixed error message to stderr and exits with a
// non-zero status, failing the surrounding `go generate` invocation.
func fatal(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	fmt.Fprintln(os.Stderr, "gen_inference_defaults: "+msg)
	os.Exit(1)
}
|
||||
@@ -76,6 +76,8 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
cfg.Options = append(cfg.Options, "use_jinja:true")
|
||||
cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
|
||||
|
||||
// Apply per-model-family inference parameter defaults (temperature, top_p, etc.)
|
||||
ApplyInferenceDefaults(cfg, f.Metadata().Name)
|
||||
}
|
||||
|
||||
// DetectThinkingSupportFromBackend calls the ModelMetadata gRPC method to detect
|
||||
|
||||
128
core/config/inference_defaults.go
Normal file
128
core/config/inference_defaults.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package config
|
||||
|
||||
//go:generate go run ./gen_inference_defaults/
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
//go:embed inference_defaults.json
var inferenceDefaultsJSON []byte

// inferenceDefaults holds the parsed inference defaults data:
// per-family parameter maps plus the ordered list of substring patterns
// used to match a model ID to a family.
type inferenceDefaults struct {
	Families map[string]map[string]float64 `json:"families"`
	Patterns []string                      `json:"patterns"`
}

// defaultsData is populated once, at package init, from the embedded JSON.
var defaultsData *inferenceDefaults
|
||||
|
||||
// init parses the embedded defaults at package load time. A parse
// failure is logged as a warning rather than panicking, so the rest of
// the package degrades gracefully (MatchModelFamily then returns nil
// because defaultsData.Patterns stays empty).
func init() {
	defaultsData = &inferenceDefaults{}
	if err := json.Unmarshal(inferenceDefaultsJSON, defaultsData); err != nil {
		xlog.Warn("failed to parse inference_defaults.json", "error", err)
	}
}
|
||||
|
||||
// normalizeModelID canonicalizes a model identifier for pattern
// matching: lowercase, org prefix (everything up to the last '/')
// removed, a trailing ".gguf" stripped, and underscores replaced with
// hyphens.
func normalizeModelID(modelID string) string {
	id := strings.ToLower(modelID)

	// Keep only the final path component, e.g.
	// "unsloth/Qwen3.5-9B-GGUF" -> "qwen3.5-9b-gguf".
	if slash := strings.LastIndex(id, "/"); slash != -1 {
		id = id[slash+1:]
	}

	id = strings.TrimSuffix(id, ".gguf")

	// Hyphens are the canonical separator used by the patterns.
	return strings.ReplaceAll(id, "_", "-")
}
|
||||
|
||||
// MatchModelFamily returns the inference defaults for the best-matching model family.
|
||||
// Patterns are checked in order (longest-match-first as defined in the JSON).
|
||||
// Returns nil if no family matches.
|
||||
func MatchModelFamily(modelID string) map[string]float64 {
|
||||
if defaultsData == nil || len(defaultsData.Patterns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := normalizeModelID(modelID)
|
||||
|
||||
for _, pattern := range defaultsData.Patterns {
|
||||
if strings.Contains(normalized, pattern) {
|
||||
if family, ok := defaultsData.Families[pattern]; ok {
|
||||
return family
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyInferenceDefaults sets recommended inference parameters on cfg based on modelIDs.
|
||||
// Tries each modelID in order; the first match wins.
|
||||
// Only fills in parameters that are not already set (nil pointers or zero values).
|
||||
func ApplyInferenceDefaults(cfg *ModelConfig, modelIDs ...string) {
|
||||
var family map[string]float64
|
||||
var matchedID string
|
||||
for _, id := range modelIDs {
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if f := MatchModelFamily(id); f != nil {
|
||||
family = f
|
||||
matchedID = id
|
||||
break
|
||||
}
|
||||
}
|
||||
if family == nil {
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Debug("[inference_defaults] applying defaults for model", "modelID", matchedID, "family", family)
|
||||
|
||||
if cfg.Temperature == nil {
|
||||
if v, ok := family["temperature"]; ok {
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.TopP == nil {
|
||||
if v, ok := family["top_p"]; ok {
|
||||
cfg.TopP = &v
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.TopK == nil {
|
||||
if v, ok := family["top_k"]; ok {
|
||||
intV := int(v)
|
||||
cfg.TopK = &intV
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.MinP == nil {
|
||||
if v, ok := family["min_p"]; ok {
|
||||
cfg.MinP = &v
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.RepeatPenalty == 0 {
|
||||
if v, ok := family["repeat_penalty"]; ok {
|
||||
cfg.RepeatPenalty = v
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.PresencePenalty == 0 {
|
||||
if v, ok := family["presence_penalty"]; ok {
|
||||
cfg.PresencePenalty = v
|
||||
}
|
||||
}
|
||||
}
|
||||
57
core/config/inference_defaults.json
Normal file
57
core/config/inference_defaults.json
Normal file
@@ -0,0 +1,57 @@
|
||||
{
|
||||
"_comment": "Auto-generated from unsloth inference_defaults.json. DO NOT EDIT. Run go generate ./core/config/ to update.",
|
||||
"families": {
|
||||
"qwen3.5": {"min_p":0,"presence_penalty":1.5,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen3-coder": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen3-next": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen3-vl": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen3": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":20,"top_p":0.95},
|
||||
"qwen2.5-coder": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"qwen2.5-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"qwen2.5-omni": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen2.5-math": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen2.5": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwen2-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"qwen2": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwq": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":40,"top_p":0.95},
|
||||
"gemma-3n": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"gemma-3": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"medgemma": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"gemma-2": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"llama-4": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.9},
|
||||
"llama-3.3": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"llama-3.2": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"llama-3.1": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"llama-3": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"phi-4": {"min_p":0,"repeat_penalty":1,"temperature":0.8,"top_k":-1,"top_p":0.95},
|
||||
"phi-3": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.9},
|
||||
"mistral-nemo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"mistral-small": {"min_p":0.01,"repeat_penalty":1,"temperature":0.15,"top_k":-1,"top_p":0.95},
|
||||
"mistral-large": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"magistral": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"ministral": {"min_p":0.01,"repeat_penalty":1,"temperature":0.15,"top_k":-1,"top_p":0.95},
|
||||
"devstral": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"pixtral": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"deepseek-r1": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95},
|
||||
"deepseek-v3": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95},
|
||||
"deepseek-ocr": {"min_p":0.01,"repeat_penalty":1,"temperature":0,"top_k":-1,"top_p":0.95},
|
||||
"glm-5": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
|
||||
"glm-4": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
|
||||
"nemotron": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":1},
|
||||
"minimax-m2.5": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":40,"top_p":0.95},
|
||||
"minimax": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":40,"top_p":0.95},
|
||||
"gpt-oss": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":0,"top_p":1},
|
||||
"granite-4": {"min_p":0.01,"repeat_penalty":1,"temperature":0,"top_k":0,"top_p":1},
|
||||
"kimi-k2": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95},
|
||||
"kimi": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95},
|
||||
"lfm2": {"min_p":0.15,"repeat_penalty":1.05,"temperature":0.1,"top_k":50,"top_p":0.1},
|
||||
"smollm": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"olmo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"falcon": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"ernie": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"seed": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95},
|
||||
"grok": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
|
||||
"mimo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}
|
||||
},
|
||||
"patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
|
||||
}
|
||||
154
core/config/inference_defaults_test.go
Normal file
154
core/config/inference_defaults_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Ginkgo suite for the embedded per-family inference defaults:
// pattern matching (MatchModelFamily) and config fill-in
// (ApplyInferenceDefaults). Expected values mirror the entries in
// core/config/inference_defaults.json.
var _ = Describe("InferenceDefaults", func() {
	Describe("MatchModelFamily", func() {
		// Longer, more specific patterns must win over their prefixes.
		It("matches qwen3.5 and not qwen3 for Qwen3.5 model", func() {
			family := config.MatchModelFamily("unsloth/Qwen3.5-9B-GGUF")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(0.7))
			Expect(family["top_p"]).To(Equal(0.8))
			Expect(family["top_k"]).To(Equal(float64(20)))
			Expect(family["presence_penalty"]).To(Equal(1.5))
		})

		It("matches llama-3.3 and not llama-3 for Llama-3.3 model", func() {
			family := config.MatchModelFamily("meta-llama/Llama-3.3-70B")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(1.5))
		})

		It("is case insensitive", func() {
			family := config.MatchModelFamily("QWEN3.5-9B")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(0.7))
		})

		It("strips org prefix", func() {
			family := config.MatchModelFamily("someorg/deepseek-r1-7b.gguf")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(0.6))
		})

		It("strips .gguf extension", func() {
			family := config.MatchModelFamily("gemma-3-4b-q4_k_m.gguf")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(1.0))
			Expect(family["top_k"]).To(Equal(float64(64)))
		})

		It("returns nil for unknown model", func() {
			family := config.MatchModelFamily("my-custom-model-v1")
			Expect(family).To(BeNil())
		})

		It("matches qwen3-coder before qwen3", func() {
			family := config.MatchModelFamily("Qwen3-Coder-8B")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(0.7))
			Expect(family["top_p"]).To(Equal(0.8))
		})

		It("matches deepseek-v3", func() {
			family := config.MatchModelFamily("deepseek-v3-base")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(0.6))
		})

		It("matches lfm2 with non-standard params", func() {
			family := config.MatchModelFamily("lfm2-7b")
			Expect(family).ToNot(BeNil())
			Expect(family["temperature"]).To(Equal(0.1))
			Expect(family["top_p"]).To(Equal(0.1))
			Expect(family["min_p"]).To(Equal(0.15))
			Expect(family["repeat_penalty"]).To(Equal(1.05))
		})

		It("includes min_p for llama-3.3", func() {
			family := config.MatchModelFamily("llama-3.3-70b")
			Expect(family).ToNot(BeNil())
			Expect(family["min_p"]).To(Equal(0.1))
		})
	})

	Describe("ApplyInferenceDefaults", func() {
		It("fills nil fields from defaults", func() {
			cfg := &config.ModelConfig{}
			config.ApplyInferenceDefaults(cfg, "gemma-3-8b")

			Expect(cfg.Temperature).ToNot(BeNil())
			Expect(*cfg.Temperature).To(Equal(1.0))
			Expect(cfg.TopP).ToNot(BeNil())
			Expect(*cfg.TopP).To(Equal(0.95))
			Expect(cfg.TopK).ToNot(BeNil())
			Expect(*cfg.TopK).To(Equal(64))
			Expect(cfg.MinP).ToNot(BeNil())
			Expect(*cfg.MinP).To(Equal(0.0))
			Expect(cfg.RepeatPenalty).To(Equal(1.0))
		})

		It("fills min_p with non-zero value", func() {
			cfg := &config.ModelConfig{}
			config.ApplyInferenceDefaults(cfg, "llama-3.3-8b")

			Expect(cfg.MinP).ToNot(BeNil())
			Expect(*cfg.MinP).To(Equal(0.1))
		})

		// User-supplied values must never be overwritten by defaults.
		It("preserves non-nil fields", func() {
			temp := 0.5
			topK := 10
			cfg := &config.ModelConfig{
				PredictionOptions: schema.PredictionOptions{
					Temperature: &temp,
					TopK:        &topK,
				},
			}
			config.ApplyInferenceDefaults(cfg, "gemma-3-8b")

			Expect(*cfg.Temperature).To(Equal(0.5))
			Expect(*cfg.TopK).To(Equal(10))
			// TopP should be filled since it was nil
			Expect(cfg.TopP).ToNot(BeNil())
			Expect(*cfg.TopP).To(Equal(0.95))
		})

		It("preserves non-zero repeat penalty", func() {
			cfg := &config.ModelConfig{
				PredictionOptions: schema.PredictionOptions{
					RepeatPenalty: 1.2,
				},
			}
			config.ApplyInferenceDefaults(cfg, "gemma-3-8b")
			Expect(cfg.RepeatPenalty).To(Equal(1.2))
		})

		It("preserves non-nil min_p", func() {
			minP := 0.05
			cfg := &config.ModelConfig{
				PredictionOptions: schema.PredictionOptions{
					MinP: &minP,
				},
			}
			config.ApplyInferenceDefaults(cfg, "llama-3.3-8b")
			Expect(*cfg.MinP).To(Equal(0.05))
		})

		It("does nothing for unknown model", func() {
			cfg := &config.ModelConfig{}
			config.ApplyInferenceDefaults(cfg, "my-custom-model")

			Expect(cfg.Temperature).To(BeNil())
			Expect(cfg.TopP).To(BeNil())
			Expect(cfg.TopK).To(BeNil())
			Expect(cfg.MinP).To(BeNil())
		})
	})
})
|
||||
@@ -375,9 +375,15 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
threads := lo.threads
|
||||
f16 := lo.f16
|
||||
debug := lo.debug
|
||||
|
||||
// Apply model-family-specific inference defaults before generic fallbacks.
|
||||
// This ensures gallery-installed and runtime-loaded models get optimal parameters.
|
||||
ApplyInferenceDefaults(cfg, cfg.Name, cfg.Model)
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
|
||||
defaultTopP := 0.95
|
||||
defaultTopK := 40
|
||||
defaultMinP := 0.0
|
||||
defaultTemp := 0.9
|
||||
// https://github.com/mudler/LocalAI/issues/2780
|
||||
defaultMirostat := 0
|
||||
@@ -400,6 +406,10 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
cfg.TopK = &defaultTopK
|
||||
}
|
||||
|
||||
if cfg.MinP == nil {
|
||||
cfg.MinP = &defaultMinP
|
||||
}
|
||||
|
||||
if cfg.TypicalP == nil {
|
||||
cfg.TypicalP = &defaultTypicalP
|
||||
}
|
||||
|
||||
@@ -114,6 +114,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
GrammarConfig: functions.GrammarConfig{
|
||||
NoGrammar: true,
|
||||
},
|
||||
AutomaticToolParsingFallback: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -249,6 +250,9 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply per-model-family inference parameter defaults
|
||||
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
|
||||
@@ -81,6 +81,9 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
},
|
||||
}
|
||||
|
||||
// Apply per-model-family inference parameter defaults
|
||||
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
|
||||
@@ -97,6 +97,9 @@ func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, err
|
||||
modelConfig.ModelType = modelType
|
||||
modelConfig.Quantization = quantization
|
||||
|
||||
// Apply per-model-family inference parameter defaults
|
||||
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
|
||||
@@ -85,6 +85,9 @@ func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
},
|
||||
}
|
||||
|
||||
// Apply per-model-family inference parameter defaults
|
||||
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
|
||||
@@ -264,6 +264,49 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
|
||||
return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
|
||||
}
|
||||
|
||||
// Apply model-family-specific inference defaults so they are persisted in the config YAML.
|
||||
// Apply to the typed struct for validation, and merge into configMap for serialization
|
||||
// (configMap preserves unknown fields that ModelConfig would drop).
|
||||
lconfig.ApplyInferenceDefaults(&modelConfig, name, modelConfig.Model)
|
||||
|
||||
// Merge inference defaults into configMap so they are persisted without losing unknown fields.
|
||||
if modelConfig.Temperature != nil {
|
||||
if _, exists := configMap["temperature"]; !exists {
|
||||
configMap["temperature"] = *modelConfig.Temperature
|
||||
}
|
||||
}
|
||||
if modelConfig.TopP != nil {
|
||||
if _, exists := configMap["top_p"]; !exists {
|
||||
configMap["top_p"] = *modelConfig.TopP
|
||||
}
|
||||
}
|
||||
if modelConfig.TopK != nil {
|
||||
if _, exists := configMap["top_k"]; !exists {
|
||||
configMap["top_k"] = *modelConfig.TopK
|
||||
}
|
||||
}
|
||||
if modelConfig.MinP != nil {
|
||||
if _, exists := configMap["min_p"]; !exists {
|
||||
configMap["min_p"] = *modelConfig.MinP
|
||||
}
|
||||
}
|
||||
if modelConfig.RepeatPenalty != 0 {
|
||||
if _, exists := configMap["repeat_penalty"]; !exists {
|
||||
configMap["repeat_penalty"] = modelConfig.RepeatPenalty
|
||||
}
|
||||
}
|
||||
if modelConfig.PresencePenalty != 0 {
|
||||
if _, exists := configMap["presence_penalty"]; !exists {
|
||||
configMap["presence_penalty"] = modelConfig.PresencePenalty
|
||||
}
|
||||
}
|
||||
|
||||
// Re-marshal from configMap to preserve unknown fields
|
||||
updatedConfigYAML, err = yaml.Marshal(configMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal config with inference defaults: %v", err)
|
||||
}
|
||||
|
||||
if valid, err := modelConfig.Validate(); !valid {
|
||||
return nil, fmt.Errorf("failed to validate updated config YAML: %v", err)
|
||||
}
|
||||
|
||||
@@ -306,6 +306,35 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
if textContent != "" {
|
||||
contentBlocks = append([]schema.AnthropicContentBlock{{Type: "text", Text: textContent}}, contentBlocks...)
|
||||
}
|
||||
} else if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && result != "" {
|
||||
// Automatic tool parsing fallback: no tools in request but model emitted tool call markup
|
||||
parsed := functions.ParseFunctionCall(result, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
stopReason = "tool_use"
|
||||
stripped := functions.StripToolCallMarkup(result)
|
||||
if stripped != "" {
|
||||
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped})
|
||||
}
|
||||
for i, fc := range parsed {
|
||||
var inputArgs map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil {
|
||||
inputArgs = map[string]interface{}{"raw": fc.Arguments}
|
||||
}
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = fmt.Sprintf("toolu_%s_%d", id, i)
|
||||
}
|
||||
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
Name: fc.Name,
|
||||
Input: inputArgs,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
stopReason = "end_turn"
|
||||
contentBlocks = []schema.AnthropicContentBlock{{Type: "text", Text: result}}
|
||||
}
|
||||
} else {
|
||||
stopReason = "end_turn"
|
||||
contentBlocks = []schema.AnthropicContentBlock{
|
||||
@@ -522,6 +551,51 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
}
|
||||
|
||||
// Automatic tool parsing fallback for streaming: when no tools were requested
|
||||
// but the model emitted tool call markup, parse and emit as tool_use blocks.
|
||||
if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && accumulatedContent != "" && toolCallsEmitted == 0 {
|
||||
parsed := functions.ParseFunctionCall(accumulatedContent, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
// Close the text content block
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
})
|
||||
currentBlockIndex++
|
||||
inToolCall = true
|
||||
|
||||
for i, fc := range parsed {
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = fmt.Sprintf("toolu_%s_%d", id, i)
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
ContentBlock: &schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
Name: fc.Name,
|
||||
},
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: currentBlockIndex,
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: fc.Arguments,
|
||||
},
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
})
|
||||
currentBlockIndex++
|
||||
toolCallsEmitted++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, close stream
|
||||
if !inToolCall {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
|
||||
@@ -751,8 +751,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls)
|
||||
}
|
||||
}
|
||||
// Collect content for MCP conversation history
|
||||
if hasMCPToolsStream && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil {
|
||||
// Collect content for MCP conversation history and automatic tool parsing fallback
|
||||
if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil {
|
||||
if s, ok := ev.Choices[0].Delta.Content.(string); ok {
|
||||
collectedContent += s
|
||||
} else if sp, ok := ev.Choices[0].Delta.Content.(*string); ok && sp != nil {
|
||||
@@ -857,6 +857,43 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
|
||||
// Automatic tool parsing fallback for streaming: when no tools were
|
||||
// requested but the model emitted tool call markup, parse and emit them.
|
||||
if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && collectedContent != "" && !toolsCalled {
|
||||
parsed := functions.ParseFunctionCall(collectedContent, config.FunctionsConfig)
|
||||
for i, fc := range parsed {
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
toolCallMsg := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{{
|
||||
Index: i,
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: fc.Name,
|
||||
Arguments: fc.Arguments,
|
||||
},
|
||||
}},
|
||||
},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(toolCallMsg)
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
c.Response().Flush()
|
||||
toolsCalled = true
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, send final stop message
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
@@ -995,6 +1032,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
}
|
||||
|
||||
// Content-based tool call fallback: if no tool calls were found,
|
||||
// try parsing the raw result — ParseFunctionCall handles detection internally.
|
||||
if len(funcResults) == 0 {
|
||||
contentFuncResults := functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
if len(contentFuncResults) > 0 {
|
||||
funcResults = contentFuncResults
|
||||
textContentToReturn = functions.StripToolCallMarkup(cbRawResult)
|
||||
}
|
||||
}
|
||||
|
||||
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
|
||||
|
||||
switch {
|
||||
@@ -1070,6 +1117,48 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
|
||||
// Automatic tool parsing fallback: when no tools/functions were in the
|
||||
// request but the model emitted tool call markup, parse and surface them.
|
||||
if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && len(result) > 0 {
|
||||
for i, choice := range result {
|
||||
if choice.Message == nil || choice.Message.Content == nil {
|
||||
continue
|
||||
}
|
||||
contentStr, ok := choice.Message.Content.(string)
|
||||
if !ok || contentStr == "" {
|
||||
continue
|
||||
}
|
||||
parsed := functions.ParseFunctionCall(contentStr, config.FunctionsConfig)
|
||||
if len(parsed) == 0 {
|
||||
continue
|
||||
}
|
||||
stripped := functions.StripToolCallMarkup(contentStr)
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
result[i].FinishReason = &toolCallsReason
|
||||
if stripped != "" {
|
||||
result[i].Message.Content = &stripped
|
||||
} else {
|
||||
result[i].Message.Content = nil
|
||||
}
|
||||
for _, fc := range parsed {
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
result[i].Message.ToolCalls = append(result[i].Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: fc.Name,
|
||||
Arguments: fc.Arguments,
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MCP server-side tool execution loop:
|
||||
// If we have MCP tools and the model returned tool_calls, execute MCP tools
|
||||
// and re-run inference with the results appended to the conversation.
|
||||
|
||||
@@ -1013,6 +1013,35 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && result != "" {
|
||||
// Automatic tool parsing fallback: no tools in request but model emitted tool call markup
|
||||
parsed := functions.ParseFunctionCall(result, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
stripped := functions.StripToolCallMarkup(result)
|
||||
if stripped != "" {
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(stripped, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
for _, fc := range parsed {
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = fmt.Sprintf("fc_%s", uuid.New().String())
|
||||
}
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "function_call", ID: fmt.Sprintf("fc_%s", uuid.New().String()),
|
||||
Status: "completed", CallID: toolCallID, Name: fc.Name, Arguments: fc.Arguments,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
@@ -1539,6 +1568,43 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && cleanedResult != "" {
|
||||
// Automatic tool parsing fallback: no tools in request but model emitted tool call markup
|
||||
parsed := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
stripped := functions.StripToolCallMarkup(cleanedResult)
|
||||
if stripped != "" {
|
||||
outputItems = append(outputItems, schema.ORItemField{
|
||||
Type: "message",
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(stripped, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
for _, fc := range parsed {
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = fmt.Sprintf("fc_%s", uuid.New().String())
|
||||
}
|
||||
outputItems = append(outputItems, schema.ORItemField{
|
||||
Type: "function_call",
|
||||
ID: fmt.Sprintf("fc_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
CallID: toolCallID,
|
||||
Name: fc.Name,
|
||||
Arguments: fc.Arguments,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
outputItems = append(outputItems, schema.ORItemField{
|
||||
Type: "message",
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Simple text response (include logprobs if available)
|
||||
messageItem := schema.ORItemField{
|
||||
@@ -2514,6 +2580,15 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
result = finalCleanedResult
|
||||
|
||||
// Automatic tool parsing fallback for streaming: parse tool calls from accumulated text
|
||||
var streamFallbackToolCalls []functions.FuncCallResults
|
||||
if cfg.FunctionsConfig.AutomaticToolParsingFallback && result != "" {
|
||||
streamFallbackToolCalls = functions.ParseFunctionCall(result, cfg.FunctionsConfig)
|
||||
if len(streamFallbackToolCalls) > 0 {
|
||||
result = functions.StripToolCallMarkup(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert logprobs for streaming events
|
||||
mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs)
|
||||
|
||||
@@ -2552,10 +2627,42 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
})
|
||||
sequenceNumber++
|
||||
|
||||
// Emit function_call items from automatic tool parsing fallback
|
||||
for _, fc := range streamFallbackToolCalls {
|
||||
toolCallID := fc.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = fmt.Sprintf("fc_%s", uuid.New().String())
|
||||
}
|
||||
outputIndex++
|
||||
functionCallItem := &schema.ORItemField{
|
||||
Type: "function_call",
|
||||
ID: toolCallID,
|
||||
Status: "completed",
|
||||
CallID: toolCallID,
|
||||
Name: fc.Name,
|
||||
Arguments: fc.Arguments,
|
||||
}
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
SequenceNumber: sequenceNumber,
|
||||
OutputIndex: &outputIndex,
|
||||
Item: functionCallItem,
|
||||
})
|
||||
sequenceNumber++
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
OutputIndex: &outputIndex,
|
||||
Item: functionCallItem,
|
||||
})
|
||||
sequenceNumber++
|
||||
collectedOutputItems = append(collectedOutputItems, *functionCallItem)
|
||||
}
|
||||
|
||||
// Emit response.completed
|
||||
now := time.Now().Unix()
|
||||
|
||||
// Collect final output items (reasoning first, then message)
|
||||
// Collect final output items (reasoning first, then messages, then tool calls)
|
||||
var finalOutputItems []schema.ORItemField
|
||||
// Add reasoning item if it exists
|
||||
if currentReasoningID != "" && finalReasoning != "" {
|
||||
@@ -2577,6 +2684,12 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
} else {
|
||||
finalOutputItems = append(finalOutputItems, *messageItem)
|
||||
}
|
||||
// Add function_call items from fallback
|
||||
for _, item := range collectedOutputItems {
|
||||
if item.Type == "function_call" {
|
||||
finalOutputItems = append(finalOutputItems, item)
|
||||
}
|
||||
}
|
||||
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{
|
||||
InputTokens: noToolTokenUsage.Prompt,
|
||||
OutputTokens: noToolTokenUsage.Completion,
|
||||
|
||||
@@ -222,6 +222,9 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
if input.MinP != nil {
|
||||
config.MinP = input.MinP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
|
||||
@@ -95,6 +95,7 @@ type PredictionOptions struct {
|
||||
// Common options between all the API calls, part of the OpenAI spec
|
||||
TopP *float64 `json:"top_p,omitempty" yaml:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty" yaml:"top_k,omitempty"`
|
||||
MinP *float64 `json:"min_p,omitempty" yaml:"min_p,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Maxtokens *int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
|
||||
Echo bool `json:"echo,omitempty" yaml:"echo,omitempty"`
|
||||
|
||||
@@ -621,25 +621,39 @@ func (p *ChatMsgParser) TryConsumeXMLToolCalls(format *XMLToolCallFormat) (bool,
|
||||
// Handle Functionary format (JSON parameters inside XML tags) - use regex parser
|
||||
if format.KeyStart == "" && format.ToolStart == "<function=" {
|
||||
// Fall back to regex-based parser for Functionary format
|
||||
results, err := parseFunctionaryFormat(p.input[p.pos:], format)
|
||||
sub := p.input[p.pos:]
|
||||
results, err := parseFunctionaryFormat(sub, format)
|
||||
if err != nil || len(results) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, result := range results {
|
||||
p.AddToolCall(result.Name, "", result.Arguments)
|
||||
}
|
||||
// Advance position past the last tool call end tag
|
||||
if last := strings.LastIndex(sub, format.ToolEnd); last >= 0 {
|
||||
p.pos += last + len(format.ToolEnd)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Handle JSON-like formats (Apriel-1.5, Xiaomi-MiMo) - use regex parser
|
||||
if format.ToolStart != "" && strings.Contains(format.ToolStart, "{\"name\"") {
|
||||
results, err := parseJSONLikeXMLFormat(p.input[p.pos:], format)
|
||||
sub := p.input[p.pos:]
|
||||
results, err := parseJSONLikeXMLFormat(sub, format)
|
||||
if err != nil || len(results) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, result := range results {
|
||||
p.AddToolCall(result.Name, "", result.Arguments)
|
||||
}
|
||||
// Advance position past the last scope/tool end tag
|
||||
endTag := format.ScopeEnd
|
||||
if endTag == "" {
|
||||
endTag = format.ToolEnd
|
||||
}
|
||||
if last := strings.LastIndex(sub, endTag); last >= 0 {
|
||||
p.pos += last + len(endTag)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -111,6 +111,11 @@ type FunctionsConfig struct {
|
||||
// If set, only this format will be tried (overrides XMLFormatPreset)
|
||||
XMLFormat *XMLToolCallFormat `yaml:"xml_format,omitempty" json:"xml_format,omitempty"`
|
||||
|
||||
// AutomaticToolParsingFallback enables automatic tool call parsing fallback:
|
||||
// - Wraps raw string arguments as {"query": raw_string} when JSON parsing fails
|
||||
// - Parses tool calls from response content even when no tools were in the request
|
||||
AutomaticToolParsingFallback bool `yaml:"automatic_tool_parsing_fallback,omitempty" json:"automatic_tool_parsing_fallback,omitempty"`
|
||||
|
||||
// DisablePEGParser disables the PEG parser and falls back to the legacy iterative parser
|
||||
DisablePEGParser bool `yaml:"disable_peg_parser,omitempty" json:"disable_peg_parser,omitempty"`
|
||||
|
||||
@@ -549,35 +554,64 @@ func getScopeOrToolStart(format *XMLToolCallFormat) string {
|
||||
return format.ToolStart
|
||||
}
|
||||
|
||||
// ParseResult holds tool calls and any non-tool-call content extracted by the parser.
|
||||
type ParseResult struct {
|
||||
ToolCalls []FuncCallResults
|
||||
Content string
|
||||
}
|
||||
|
||||
// tryParseXMLFromScopeStart finds the first occurrence of scopeStart (or format.ToolStart),
|
||||
// splits the input there, and parses only the suffix as XML tool calls. Returns (toolCalls, true)
|
||||
// if any tool calls were parsed, else (nil, false). This mimics llama.cpp's PEG order so that
|
||||
// splits the input there, and parses only the suffix as XML tool calls. Returns (result, true)
|
||||
// if any tool calls were parsed, else (empty, false). This mimics llama.cpp's PEG order so that
|
||||
// reasoning or content before the tool block does not cause "whitespace only before scope" to fail.
|
||||
func tryParseXMLFromScopeStart(s string, format *XMLToolCallFormat, isPartial bool) ([]FuncCallResults, bool) {
|
||||
func tryParseXMLFromScopeStart(s string, format *XMLToolCallFormat, isPartial bool) (ParseResult, bool) {
|
||||
if format == nil {
|
||||
return nil, false
|
||||
return ParseResult{}, false
|
||||
}
|
||||
scopeStart := getScopeOrToolStart(format)
|
||||
if scopeStart == "" {
|
||||
return nil, false
|
||||
return ParseResult{}, false
|
||||
}
|
||||
idx := strings.Index(s, scopeStart)
|
||||
if idx < 0 {
|
||||
return nil, false
|
||||
return ParseResult{}, false
|
||||
}
|
||||
toolCallsPart := s[idx:]
|
||||
parser := NewChatMsgParser(toolCallsPart, isPartial)
|
||||
success, err := parser.TryConsumeXMLToolCalls(format)
|
||||
if err != nil {
|
||||
if _, ok := err.(*ChatMsgPartialException); ok && isPartial {
|
||||
return parser.ToolCalls(), len(parser.ToolCalls()) > 0
|
||||
tc := parser.ToolCalls()
|
||||
if len(tc) > 0 {
|
||||
return ParseResult{ToolCalls: tc, Content: buildContent(s[:idx], parser)}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
return ParseResult{}, false
|
||||
}
|
||||
if success && len(parser.ToolCalls()) > 0 {
|
||||
return parser.ToolCalls(), true
|
||||
return ParseResult{
|
||||
ToolCalls: parser.ToolCalls(),
|
||||
Content: buildContent(s[:idx], parser),
|
||||
}, true
|
||||
}
|
||||
return nil, false
|
||||
return ParseResult{}, false
|
||||
}
|
||||
|
||||
// buildContent assembles the non-tool-call content from the text before the tool
|
||||
// block, any content tracked by the parser, and any unconsumed trailing text.
|
||||
func buildContent(before string, parser *ChatMsgParser) string {
|
||||
var parts []string
|
||||
if b := strings.TrimSpace(before); b != "" {
|
||||
parts = append(parts, b)
|
||||
}
|
||||
if pc := strings.TrimSpace(parser.Content()); pc != "" {
|
||||
parts = append(parts, pc)
|
||||
}
|
||||
remaining := parser.Input()[parser.Pos():]
|
||||
if t := strings.TrimSpace(remaining); t != "" {
|
||||
parts = append(parts, t)
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// ParseXMLIterative parses XML tool calls using the iterative parser
|
||||
@@ -587,15 +621,15 @@ func tryParseXMLFromScopeStart(s string, format *XMLToolCallFormat, isPartial bo
|
||||
func ParseXMLIterative(s string, format *XMLToolCallFormat, isPartial bool) ([]FuncCallResults, error) {
|
||||
// Try split-on-scope first so reasoning/content before tool block is skipped
|
||||
if format != nil {
|
||||
if results, ok := tryParseXMLFromScopeStart(s, format, isPartial); ok {
|
||||
return results, nil
|
||||
if pr, ok := tryParseXMLFromScopeStart(s, format, isPartial); ok {
|
||||
return pr.ToolCalls, nil
|
||||
}
|
||||
} else {
|
||||
formats := getAllXMLFormats()
|
||||
for _, fmtPreset := range formats {
|
||||
if fmtPreset.format != nil {
|
||||
if results, ok := tryParseXMLFromScopeStart(s, fmtPreset.format, isPartial); ok {
|
||||
return results, nil
|
||||
if pr, ok := tryParseXMLFromScopeStart(s, fmtPreset.format, isPartial); ok {
|
||||
return pr.ToolCalls, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -885,8 +919,17 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
|
||||
// Marshal arguments to JSON string (handles both object and string cases)
|
||||
var d []byte
|
||||
if argsStr, ok := args.(string); ok {
|
||||
// Already a string, use it directly
|
||||
d = []byte(argsStr)
|
||||
// Check if the string is valid JSON; if not, auto-heal if enabled
|
||||
var testJSON map[string]any
|
||||
if json.Unmarshal([]byte(argsStr), &testJSON) == nil {
|
||||
d = []byte(argsStr)
|
||||
} else if functionConfig.AutomaticToolParsingFallback {
|
||||
healed := map[string]string{"query": argsStr}
|
||||
d, _ = json.Marshal(healed)
|
||||
xlog.Debug("Automatic tool parsing fallback: wrapped raw string arguments", "raw", argsStr)
|
||||
} else {
|
||||
d = []byte(argsStr)
|
||||
}
|
||||
} else {
|
||||
// Object, marshal to JSON
|
||||
d, _ = json.Marshal(args)
|
||||
|
||||
@@ -2536,4 +2536,61 @@ def hello():
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Automatic tool parsing fallback", func() {
|
||||
It("wraps malformed string args as query when enabled", func() {
|
||||
input := `{"name": "web_search", "arguments": "search for cats"}`
|
||||
cfg := FunctionsConfig{AutomaticToolParsingFallback: true}
|
||||
results := ParseFunctionCall(input, cfg)
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Name).To(Equal("web_search"))
|
||||
|
||||
var args map[string]string
|
||||
err := json.Unmarshal([]byte(results[0].Arguments), &args)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(args["query"]).To(Equal("search for cats"))
|
||||
})
|
||||
|
||||
It("preserves malformed string args as-is when disabled", func() {
|
||||
input := `{"name": "web_search", "arguments": "search for cats"}`
|
||||
cfg := FunctionsConfig{AutomaticToolParsingFallback: false}
|
||||
results := ParseFunctionCall(input, cfg)
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Arguments).To(Equal("search for cats"))
|
||||
})
|
||||
|
||||
It("does not alter valid JSON string args", func() {
|
||||
input := `{"name": "web_search", "arguments": "{\"query\": \"cats\"}"}`
|
||||
cfg := FunctionsConfig{AutomaticToolParsingFallback: true}
|
||||
results := ParseFunctionCall(input, cfg)
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Arguments).To(Equal(`{"query": "cats"}`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("StripToolCallMarkup", func() {
|
||||
It("removes functionary-style function blocks and keeps surrounding text", func() {
|
||||
input := `Text before <function=search>{"q":"cats"}</function> text after`
|
||||
result := StripToolCallMarkup(input)
|
||||
Expect(result).To(Equal("Text before text after"))
|
||||
})
|
||||
|
||||
It("removes qwen3-coder-style tool_call blocks and keeps preceding text", func() {
|
||||
input := `Here is my answer <tool_call><function=search><parameter=q>cats</parameter></function></tool_call>`
|
||||
result := StripToolCallMarkup(input)
|
||||
Expect(result).To(Equal("Here is my answer"))
|
||||
})
|
||||
|
||||
It("returns empty string when content is only tool calls", func() {
|
||||
input := `<function=search>{"q":"cats"}</function>`
|
||||
result := StripToolCallMarkup(input)
|
||||
Expect(result).To(Equal(""))
|
||||
})
|
||||
|
||||
It("preserves text with no tool call markup", func() {
|
||||
input := "Just a normal response with no tools"
|
||||
result := StripToolCallMarkup(input)
|
||||
Expect(result).To(Equal("Just a normal response with no tools"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
18
pkg/functions/strip.go
Normal file
18
pkg/functions/strip.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package functions
|
||||
|
||||
import "strings"
|
||||
|
||||
// StripToolCallMarkup extracts the non-tool-call content from a string
|
||||
// by reusing the iterative XML parser which already separates content
|
||||
// from tool calls. Returns the remaining text, trimmed.
|
||||
func StripToolCallMarkup(content string) string {
|
||||
for _, fmtPreset := range getAllXMLFormats() {
|
||||
if fmtPreset.format == nil {
|
||||
continue
|
||||
}
|
||||
if pr, ok := tryParseXMLFromScopeStart(content, fmtPreset.format, false); ok && len(pr.ToolCalls) > 0 {
|
||||
return strings.TrimSpace(pr.Content)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
Reference in New Issue
Block a user