diff --git a/.github/workflows/bump-inference-defaults.yml b/.github/workflows/bump-inference-defaults.yml new file mode 100644 index 000000000..302012a3c --- /dev/null +++ b/.github/workflows/bump-inference-defaults.yml @@ -0,0 +1,48 @@ +name: Bump inference defaults + +on: + schedule: + # Run daily at 06:00 UTC + - cron: '0 6 * * *' + workflow_dispatch: # Allow manual trigger + +permissions: + contents: write + pull-requests: write + +jobs: + bump: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Re-fetch inference defaults + run: make generate-force + + - name: Check for changes + id: diff + run: | + if git diff --quiet core/config/inference_defaults.json; then + echo "changed=false" >> "$GITHUB_OUTPUT" + else + echo "changed=true" >> "$GITHUB_OUTPUT" + fi + + - name: Create Pull Request + if: steps.diff.outputs.changed == 'true' + uses: peter-evans/create-pull-request@v7 + with: + commit-message: "chore: bump inference defaults from unsloth" + title: "chore: bump inference defaults from unsloth" + body: | + Auto-generated update of `core/config/inference_defaults.json` from + [unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json). + + This PR was created automatically by the `bump-inference-defaults` workflow. 
+ branch: chore/bump-inference-defaults + delete-branch: true + labels: automated diff --git a/Makefile b/Makefile index 459aa9208..61bceb7e3 100644 --- a/Makefile +++ b/Makefile @@ -107,7 +107,7 @@ core/http/react-ui/dist: react-ui ## Build: -build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project +build: protogen-go generate install-go-tools core/http/react-ui/dist ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) @@ -398,6 +398,16 @@ protogen-go: protoc install-go-tools ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \ backend/backend.proto +core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing) + $(GOCMD) generate ./core/config/... + +.PHONY: generate +generate: core/config/inference_defaults.json ## Ensure inference defaults exist + +.PHONY: generate-force +generate-force: ## Re-fetch inference defaults from unsloth (always) + $(GOCMD) generate ./core/config/... + .PHONY: protogen-go-clean protogen-go-clean: $(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go diff --git a/backend/backend.proto b/backend/backend.proto index f89c18571..3f01efbe1 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -178,6 +178,7 @@ message PredictOptions { int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter) int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter) map Metadata = 52; // Generic per-request metadata (e.g., enable_thinking) + float MinP = 53; // Minimum probability sampling threshold (0.0 = disabled) } // ToolCallDelta represents an incremental tool call update from the C++ parser. 
diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index 67b5632af..89f03bf7d 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -136,6 +136,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const data["mirostat_eta"] = predict->mirostateta(); data["n_keep"] = predict->nkeep(); data["seed"] = predict->seed(); + data["min_p"] = predict->minp(); std::string grammar_str = predict->grammar(); diff --git a/core/backend/options.go b/core/backend/options.go index 4275a6f07..71b9d682a 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -252,6 +252,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions TopP: float32(*c.TopP), NDraft: c.NDraft, TopK: int32(*c.TopK), + MinP: float32(*c.MinP), Tokens: int32(*c.Maxtokens), Threads: int32(*c.Threads), PromptCacheAll: c.PromptCacheAll, diff --git a/core/config/gen_inference_defaults/README.md b/core/config/gen_inference_defaults/README.md new file mode 100644 index 000000000..7060aaabd --- /dev/null +++ b/core/config/gen_inference_defaults/README.md @@ -0,0 +1,30 @@ +# gen_inference_defaults + +This tool fetches per-model-family inference parameter defaults from [unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json), validates the data, remaps field names to LocalAI conventions, and writes `core/config/inference_defaults.json`. + +## What it does + +1. Fetches the latest `inference_defaults.json` from unsloth's repo +2. Validates that every entry has required fields (`temperature`, `top_p`, `top_k`) +3. Validates that every pattern references an existing family +4. Warns if pattern ordering would cause shorter prefixes to shadow longer ones +5. Remaps `repetition_penalty` → `repeat_penalty` (LocalAI naming) +6. 
Filters to allowed fields only: `temperature`, `top_p`, `top_k`, `min_p`, `repeat_penalty`, `presence_penalty` +7. Writes the validated JSON to `core/config/inference_defaults.json` + +## Usage + +```bash +# Only regenerate if the file is missing (runs during make build) +make generate + +# Force re-fetch from unsloth +make generate-force + +# Or directly via go generate +go generate ./core/config/... +``` + +## Automation + +The GitHub Actions workflow `.github/workflows/bump-inference-defaults.yml` runs `make generate-force` daily and opens a PR if the upstream data changed. diff --git a/core/config/gen_inference_defaults/main.go b/core/config/gen_inference_defaults/main.go new file mode 100644 index 000000000..4f83b7de0 --- /dev/null +++ b/core/config/gen_inference_defaults/main.go @@ -0,0 +1,222 @@ +// gen_inference_defaults fetches unsloth's inference_defaults.json, +// validates its structure, remaps field names to LocalAI conventions, +// and writes the result to core/config/inference_defaults.json. 
// gen_inference_defaults fetches unsloth's inference_defaults.json,
// validates its structure, remaps field names to LocalAI conventions,
// and writes the result to inference_defaults.json in the current
// working directory.
//
// Run via: go generate ./core/config/

const (
	// unslothURL is the raw upstream document downloaded by the generator.
	unslothURL = "https://raw.githubusercontent.com/unslothai/unsloth/main/studio/backend/assets/configs/inference_defaults.json"
	// outputFile is the path (relative to the working directory) that is written.
	outputFile = "inference_defaults.json"
)

// unslothDefaults mirrors the upstream JSON structure.
type unslothDefaults struct {
	Comment  string                        `json:"_comment"`
	Families map[string]map[string]float64 `json:"families"`
	Patterns []string                      `json:"patterns"`
}

// localAIDefaults is our output structure.
type localAIDefaults struct {
	Comment  string                        `json:"_comment"`
	Families map[string]map[string]float64 `json:"families"`
	Patterns []string                      `json:"patterns"`
}

// requiredFields are the fields every family entry must have.
var requiredFields = []string{"temperature", "top_p", "top_k"}

// fieldRemap maps unsloth field names to LocalAI field names.
var fieldRemap = map[string]string{
	"repetition_penalty": "repeat_penalty",
}

// allowedFields are the only fields we keep (after remapping).
var allowedFields = map[string]bool{
	"temperature":      true,
	"top_p":            true,
	"top_k":            true,
	"min_p":            true,
	"repeat_penalty":   true,
	"presence_penalty": true,
}

// main downloads the upstream defaults, validates them, converts every family
// to LocalAI field names and writes the result to outputFile. Any failure
// terminates the process with exit status 1 via fatal.
func main() {
	fmt.Fprintf(os.Stderr, "Fetching %s ...\n", unslothURL)

	resp, err := http.Get(unslothURL)
	if err != nil {
		fatal("fetch failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		fatal("fetch returned HTTP %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		fatal("read body: %v", err)
	}

	var upstream unslothDefaults
	if err := json.Unmarshal(body, &upstream); err != nil {
		fatal("parse upstream JSON: %v", err)
	}

	// Validate structure
	if len(upstream.Families) == 0 {
		fatal("upstream has no families")
	}
	if len(upstream.Patterns) == 0 {
		fatal("upstream has no patterns")
	}

	// Validate every pattern references a family
	for _, p := range upstream.Patterns {
		if _, ok := upstream.Families[p]; !ok {
			fatal("pattern %q has no corresponding family entry", p)
		}
	}

	output := localAIDefaults{
		Comment:  "Auto-generated from unsloth inference_defaults.json. DO NOT EDIT. Run go generate ./core/config/ to update.",
		Families: make(map[string]map[string]float64, len(upstream.Families)),
		Patterns: upstream.Patterns,
	}

	// Visit families in sorted order so that the first validation error
	// reported is the same on every run.
	familyNames := make([]string, 0, len(upstream.Families))
	for name := range upstream.Families {
		familyNames = append(familyNames, name)
	}
	sort.Strings(familyNames)

	for _, name := range familyNames {
		params := upstream.Families[name]
		for _, req := range requiredFields {
			if !hasRequiredField(params, req) {
				fatal("family %q missing required field %q", name, req)
			}
		}
		output.Families[name] = remapFields(params)
	}

	// Warn about patterns that would shadow a later, more specific pattern.
	validatePatternOrder(output.Patterns)

	// Marshal with ordered keys for readability
	data, err := marshalOrdered(output)
	if err != nil {
		fatal("marshal output: %v", err)
	}

	if err := os.WriteFile(outputFile, data, 0644); err != nil {
		fatal("write %s: %v", outputFile, err)
	}

	fmt.Fprintf(os.Stderr, "Written %s (%d families, %d patterns)\n",
		outputFile, len(output.Families), len(output.Patterns))
}

// hasRequiredField reports whether params provides req, either under that
// exact name or under an upstream name that fieldRemap translates to req.
func hasRequiredField(params map[string]float64, req string) bool {
	if _, ok := params[req]; ok {
		return true
	}
	for upstreamName, localName := range fieldRemap {
		if localName != req {
			continue
		}
		if _, ok := params[upstreamName]; ok {
			return true
		}
	}
	return false
}

// remapFields renames upstream fields according to fieldRemap and drops every
// field that is not listed in allowedFields.
//
// When params contains both an upstream name and the LocalAI name it maps to,
// the value stored under the LocalAI name wins. Without this rule the result
// would depend on map iteration order.
func remapFields(params map[string]float64) map[string]float64 {
	remapped := make(map[string]float64, len(params))
	for k, v := range params {
		if newName, ok := fieldRemap[k]; ok {
			if _, native := params[newName]; native {
				// The explicitly named LocalAI field takes precedence.
				continue
			}
			k = newName
		}
		if allowedFields[k] {
			remapped[k] = v
		}
	}
	return remapped
}

// patternWarnings returns one message for every pattern that appears before a
// later pattern containing it (e.g., "qwen3" before "qwen3.5").
//
// The lookup in core/config (MatchModelFamily) picks the first pattern that is
// a substring of the model ID, so any earlier pattern contained in a later one
// makes the later one unreachable, whether or not it is a prefix.
func patternWarnings(patterns []string) []string {
	var warnings []string
	for i, p := range patterns {
		for j := i + 1; j < len(patterns); j++ {
			var relation string
			switch {
			case strings.HasPrefix(patterns[j], p):
				relation = "is a prefix of"
			case strings.Contains(patterns[j], p):
				relation = "is a substring of"
			default:
				continue
			}
			warnings = append(warnings, fmt.Sprintf(
				"WARNING: pattern %q at index %d %s %q at index %d — longer match should come first",
				p, i, relation, patterns[j], j))
		}
	}
	return warnings
}

// validatePatternOrder prints a warning to stderr for every pattern that
// shadows a later, longer pattern. It never aborts generation.
func validatePatternOrder(patterns []string) {
	for _, w := range patternWarnings(patterns) {
		fmt.Fprintf(os.Stderr, "%s\n", w)
	}
}

// marshalOrdered produces JSON with families in pattern order for readability.
// Families that are not referenced by any pattern follow in sorted name order,
// so the same input always yields byte-identical output.
func marshalOrdered(d localAIDefaults) ([]byte, error) {
	ordered := make([]string, 0, len(d.Families))
	seen := make(map[string]bool, len(d.Families))
	for _, p := range d.Patterns {
		if _, ok := d.Families[p]; ok && !seen[p] {
			ordered = append(ordered, p)
			seen[p] = true
		}
	}

	// Map iteration order is random; sort the leftovers before appending.
	var rest []string
	for name := range d.Families {
		if !seen[name] {
			rest = append(rest, name)
		}
	}
	sort.Strings(rest)
	ordered = append(ordered, rest...)

	var sb strings.Builder
	sb.WriteString("{\n")
	fmt.Fprintf(&sb, " %q: %q,\n", "_comment", d.Comment)
	sb.WriteString(" \"families\": {\n")

	for i, name := range ordered {
		// json.Marshal emits map keys in sorted order.
		paramJSON, err := json.Marshal(d.Families[name])
		if err != nil {
			return nil, err
		}
		comma := ","
		if i == len(ordered)-1 {
			comma = ""
		}
		fmt.Fprintf(&sb, " %q: %s%s\n", name, paramJSON, comma)
	}

	sb.WriteString(" },\n")

	// Patterns array
	patternsJSON, err := json.Marshal(d.Patterns)
	if err != nil {
		return nil, err
	}
	fmt.Fprintf(&sb, " \"patterns\": %s\n", patternsJSON)
	sb.WriteString("}\n")

	return []byte(sb.String()), nil
}

// fatal prints a prefixed message to stderr and exits with status 1.
func fatal(format string, args ...any) {
	fmt.Fprintf(os.Stderr, "gen_inference_defaults: "+format+"\n", args...)
	os.Exit(1)
}
+ os.Exit(1) +} diff --git a/core/config/gguf.go b/core/config/gguf.go index ae95d4a13..0c8255478 100644 --- a/core/config/gguf.go +++ b/core/config/gguf.go @@ -76,6 +76,8 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) { cfg.Options = append(cfg.Options, "use_jinja:true") cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT") + // Apply per-model-family inference parameter defaults (temperature, top_p, etc.) + ApplyInferenceDefaults(cfg, f.Metadata().Name) } // DetectThinkingSupportFromBackend calls the ModelMetadata gRPC method to detect diff --git a/core/config/inference_defaults.go b/core/config/inference_defaults.go new file mode 100644 index 000000000..cecb2ad99 --- /dev/null +++ b/core/config/inference_defaults.go @@ -0,0 +1,128 @@ +package config + +//go:generate go run ./gen_inference_defaults/ + +import ( + _ "embed" + "encoding/json" + "strings" + + "github.com/mudler/xlog" +) + +//go:embed inference_defaults.json +var inferenceDefaultsJSON []byte + +// inferenceDefaults holds the parsed inference defaults data +type inferenceDefaults struct { + Families map[string]map[string]float64 `json:"families"` + Patterns []string `json:"patterns"` +} + +var defaultsData *inferenceDefaults + +func init() { + defaultsData = &inferenceDefaults{} + if err := json.Unmarshal(inferenceDefaultsJSON, defaultsData); err != nil { + xlog.Warn("failed to parse inference_defaults.json", "error", err) + } +} + +// normalizeModelID lowercases, strips org prefix (before /), and removes .gguf extension +func normalizeModelID(modelID string) string { + modelID = strings.ToLower(modelID) + + // Strip org prefix (e.g., "unsloth/Qwen3.5-9B-GGUF" -> "qwen3.5-9b-gguf") + if idx := strings.LastIndex(modelID, "/"); idx >= 0 { + modelID = modelID[idx+1:] + } + + // Strip .gguf extension + modelID = strings.TrimSuffix(modelID, ".gguf") + + // Replace underscores with hyphens for matching + modelID = strings.ReplaceAll(modelID, "_", "-") + + 
return modelID +} + +// MatchModelFamily returns the inference defaults for the best-matching model family. +// Patterns are checked in order (longest-match-first as defined in the JSON). +// Returns nil if no family matches. +func MatchModelFamily(modelID string) map[string]float64 { + if defaultsData == nil || len(defaultsData.Patterns) == 0 { + return nil + } + + normalized := normalizeModelID(modelID) + + for _, pattern := range defaultsData.Patterns { + if strings.Contains(normalized, pattern) { + if family, ok := defaultsData.Families[pattern]; ok { + return family + } + } + } + + return nil +} + +// ApplyInferenceDefaults sets recommended inference parameters on cfg based on modelIDs. +// Tries each modelID in order; the first match wins. +// Only fills in parameters that are not already set (nil pointers or zero values). +func ApplyInferenceDefaults(cfg *ModelConfig, modelIDs ...string) { + var family map[string]float64 + var matchedID string + for _, id := range modelIDs { + if id == "" { + continue + } + if f := MatchModelFamily(id); f != nil { + family = f + matchedID = id + break + } + } + if family == nil { + return + } + + xlog.Debug("[inference_defaults] applying defaults for model", "modelID", matchedID, "family", family) + + if cfg.Temperature == nil { + if v, ok := family["temperature"]; ok { + cfg.Temperature = &v + } + } + + if cfg.TopP == nil { + if v, ok := family["top_p"]; ok { + cfg.TopP = &v + } + } + + if cfg.TopK == nil { + if v, ok := family["top_k"]; ok { + intV := int(v) + cfg.TopK = &intV + } + } + + if cfg.MinP == nil { + if v, ok := family["min_p"]; ok { + cfg.MinP = &v + } + } + + if cfg.RepeatPenalty == 0 { + if v, ok := family["repeat_penalty"]; ok { + cfg.RepeatPenalty = v + } + } + + if cfg.PresencePenalty == 0 { + if v, ok := family["presence_penalty"]; ok { + cfg.PresencePenalty = v + } + } +} diff --git a/core/config/inference_defaults.json b/core/config/inference_defaults.json new file mode 100644 index 000000000..56109a040 
--- /dev/null +++ b/core/config/inference_defaults.json @@ -0,0 +1,57 @@ +{ + "_comment": "Auto-generated from unsloth inference_defaults.json. DO NOT EDIT. Run go generate ./core/config/ to update.", + "families": { + "qwen3.5": {"min_p":0,"presence_penalty":1.5,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen3-coder": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen3-next": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen3-vl": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen3": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":20,"top_p":0.95}, + "qwen2.5-coder": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "qwen2.5-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "qwen2.5-omni": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen2.5-math": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen2.5": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwen2-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "qwen2": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8}, + "qwq": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":40,"top_p":0.95}, + "gemma-3n": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95}, + "gemma-3": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95}, + "medgemma": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95}, + "gemma-2": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95}, + "llama-4": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.9}, + "llama-3.3": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "llama-3.2": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, 
+ "llama-3.1": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "llama-3": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "phi-4": {"min_p":0,"repeat_penalty":1,"temperature":0.8,"top_k":-1,"top_p":0.95}, + "phi-3": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.9}, + "mistral-nemo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "mistral-small": {"min_p":0.01,"repeat_penalty":1,"temperature":0.15,"top_k":-1,"top_p":0.95}, + "mistral-large": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "magistral": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "ministral": {"min_p":0.01,"repeat_penalty":1,"temperature":0.15,"top_k":-1,"top_p":0.95}, + "devstral": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "pixtral": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95}, + "deepseek-r1": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95}, + "deepseek-v3": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95}, + "deepseek-ocr": {"min_p":0.01,"repeat_penalty":1,"temperature":0,"top_k":-1,"top_p":0.95}, + "glm-5": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95}, + "glm-4": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95}, + "nemotron": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":1}, + "minimax-m2.5": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":40,"top_p":0.95}, + "minimax": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":40,"top_p":0.95}, + "gpt-oss": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":0,"top_p":1}, + "granite-4": {"min_p":0.01,"repeat_penalty":1,"temperature":0,"top_k":0,"top_p":1}, + "kimi-k2": {"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95}, + "kimi": 
{"min_p":0.01,"repeat_penalty":1,"temperature":0.6,"top_k":-1,"top_p":0.95}, + "lfm2": {"min_p":0.15,"repeat_penalty":1.05,"temperature":0.1,"top_k":50,"top_p":0.1}, + "smollm": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "olmo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "falcon": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "ernie": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "seed": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}, + "grok": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95}, + "mimo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95} + }, + "patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"] +} diff --git a/core/config/inference_defaults_test.go b/core/config/inference_defaults_test.go new file mode 100644 index 000000000..13db1765a --- /dev/null +++ b/core/config/inference_defaults_test.go @@ -0,0 +1,154 @@ +package config_test + +import ( + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("InferenceDefaults", func() { + Describe("MatchModelFamily", func() { + It("matches qwen3.5 and not qwen3 for Qwen3.5 model", func() { + family := config.MatchModelFamily("unsloth/Qwen3.5-9B-GGUF") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(0.7)) + Expect(family["top_p"]).To(Equal(0.8)) + Expect(family["top_k"]).To(Equal(float64(20))) + Expect(family["presence_penalty"]).To(Equal(1.5)) + }) + + It("matches llama-3.3 and not llama-3 for Llama-3.3 model", func() { + family := config.MatchModelFamily("meta-llama/Llama-3.3-70B") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(1.5)) + }) + + It("is case insensitive", func() { + family := config.MatchModelFamily("QWEN3.5-9B") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(0.7)) + }) + + It("strips org prefix", func() { + family := config.MatchModelFamily("someorg/deepseek-r1-7b.gguf") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(0.6)) + }) + + It("strips .gguf extension", func() { + family := config.MatchModelFamily("gemma-3-4b-q4_k_m.gguf") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(1.0)) + Expect(family["top_k"]).To(Equal(float64(64))) + }) + + It("returns nil for unknown model", func() { + family := config.MatchModelFamily("my-custom-model-v1") + Expect(family).To(BeNil()) + }) + + It("matches qwen3-coder before qwen3", func() { + family := config.MatchModelFamily("Qwen3-Coder-8B") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(0.7)) + Expect(family["top_p"]).To(Equal(0.8)) + }) + + It("matches deepseek-v3", func() { + family := config.MatchModelFamily("deepseek-v3-base") + Expect(family).ToNot(BeNil()) + Expect(family["temperature"]).To(Equal(0.6)) + }) + + It("matches lfm2 with non-standard params", func() { + family := config.MatchModelFamily("lfm2-7b") + Expect(family).ToNot(BeNil()) + 
Expect(family["temperature"]).To(Equal(0.1)) + Expect(family["top_p"]).To(Equal(0.1)) + Expect(family["min_p"]).To(Equal(0.15)) + Expect(family["repeat_penalty"]).To(Equal(1.05)) + }) + + It("includes min_p for llama-3.3", func() { + family := config.MatchModelFamily("llama-3.3-70b") + Expect(family).ToNot(BeNil()) + Expect(family["min_p"]).To(Equal(0.1)) + }) + }) + + Describe("ApplyInferenceDefaults", func() { + It("fills nil fields from defaults", func() { + cfg := &config.ModelConfig{} + config.ApplyInferenceDefaults(cfg, "gemma-3-8b") + + Expect(cfg.Temperature).ToNot(BeNil()) + Expect(*cfg.Temperature).To(Equal(1.0)) + Expect(cfg.TopP).ToNot(BeNil()) + Expect(*cfg.TopP).To(Equal(0.95)) + Expect(cfg.TopK).ToNot(BeNil()) + Expect(*cfg.TopK).To(Equal(64)) + Expect(cfg.MinP).ToNot(BeNil()) + Expect(*cfg.MinP).To(Equal(0.0)) + Expect(cfg.RepeatPenalty).To(Equal(1.0)) + }) + + It("fills min_p with non-zero value", func() { + cfg := &config.ModelConfig{} + config.ApplyInferenceDefaults(cfg, "llama-3.3-8b") + + Expect(cfg.MinP).ToNot(BeNil()) + Expect(*cfg.MinP).To(Equal(0.1)) + }) + + It("preserves non-nil fields", func() { + temp := 0.5 + topK := 10 + cfg := &config.ModelConfig{ + PredictionOptions: schema.PredictionOptions{ + Temperature: &temp, + TopK: &topK, + }, + } + config.ApplyInferenceDefaults(cfg, "gemma-3-8b") + + Expect(*cfg.Temperature).To(Equal(0.5)) + Expect(*cfg.TopK).To(Equal(10)) + // TopP should be filled since it was nil + Expect(cfg.TopP).ToNot(BeNil()) + Expect(*cfg.TopP).To(Equal(0.95)) + }) + + It("preserves non-zero repeat penalty", func() { + cfg := &config.ModelConfig{ + PredictionOptions: schema.PredictionOptions{ + RepeatPenalty: 1.2, + }, + } + config.ApplyInferenceDefaults(cfg, "gemma-3-8b") + Expect(cfg.RepeatPenalty).To(Equal(1.2)) + }) + + It("preserves non-nil min_p", func() { + minP := 0.05 + cfg := &config.ModelConfig{ + PredictionOptions: schema.PredictionOptions{ + MinP: &minP, + }, + } + config.ApplyInferenceDefaults(cfg, 
"llama-3.3-8b") + Expect(*cfg.MinP).To(Equal(0.05)) + }) + + It("does nothing for unknown model", func() { + cfg := &config.ModelConfig{} + config.ApplyInferenceDefaults(cfg, "my-custom-model") + + Expect(cfg.Temperature).To(BeNil()) + Expect(cfg.TopP).To(BeNil()) + Expect(cfg.TopK).To(BeNil()) + Expect(cfg.MinP).To(BeNil()) + }) + }) +}) diff --git a/core/config/model_config.go b/core/config/model_config.go index defbbd525..0d148eac1 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -375,9 +375,15 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) { threads := lo.threads f16 := lo.f16 debug := lo.debug + + // Apply model-family-specific inference defaults before generic fallbacks. + // This ensures gallery-installed and runtime-loaded models get optimal parameters. + ApplyInferenceDefaults(cfg, cfg.Name, cfg.Model) + // https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22 defaultTopP := 0.95 defaultTopK := 40 + defaultMinP := 0.0 defaultTemp := 0.9 // https://github.com/mudler/LocalAI/issues/2780 defaultMirostat := 0 @@ -400,6 +406,10 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.TopK = &defaultTopK } + if cfg.MinP == nil { + cfg.MinP = &defaultMinP + } + if cfg.TypicalP == nil { cfg.TypicalP = &defaultTypicalP } diff --git a/core/gallery/importers/llama-cpp.go b/core/gallery/importers/llama-cpp.go index f0d5915c5..edd938791 100644 --- a/core/gallery/importers/llama-cpp.go +++ b/core/gallery/importers/llama-cpp.go @@ -114,6 +114,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) GrammarConfig: functions.GrammarConfig{ NoGrammar: true, }, + AutomaticToolParsingFallback: true, }, } @@ -249,6 +250,9 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) } } + // Apply per-model-family inference parameter defaults + config.ApplyInferenceDefaults(&modelConfig, details.URI) + data, err 
:= yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err diff --git a/core/gallery/importers/mlx.go b/core/gallery/importers/mlx.go index faa28846f..7ab513f6d 100644 --- a/core/gallery/importers/mlx.go +++ b/core/gallery/importers/mlx.go @@ -81,6 +81,9 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) { }, } + // Apply per-model-family inference parameter defaults + config.ApplyInferenceDefaults(&modelConfig, details.URI) + data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err diff --git a/core/gallery/importers/transformers.go b/core/gallery/importers/transformers.go index cd09c366d..5a4732ca8 100644 --- a/core/gallery/importers/transformers.go +++ b/core/gallery/importers/transformers.go @@ -97,6 +97,9 @@ func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, err modelConfig.ModelType = modelType modelConfig.Quantization = quantization + // Apply per-model-family inference parameter defaults + config.ApplyInferenceDefaults(&modelConfig, details.URI) + data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err diff --git a/core/gallery/importers/vllm.go b/core/gallery/importers/vllm.go index be544662a..88baef1fe 100644 --- a/core/gallery/importers/vllm.go +++ b/core/gallery/importers/vllm.go @@ -85,6 +85,9 @@ func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) { }, } + // Apply per-model-family inference parameter defaults + config.ApplyInferenceDefaults(&modelConfig, details.URI) + data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err diff --git a/core/gallery/models.go b/core/gallery/models.go index a2da4442f..3aa5e4db8 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -264,6 +264,49 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err) } + // Apply 
model-family-specific inference defaults so they are persisted in the config YAML. + // Apply to the typed struct for validation, and merge into configMap for serialization + // (configMap preserves unknown fields that ModelConfig would drop). + lconfig.ApplyInferenceDefaults(&modelConfig, name, modelConfig.Model) + + // Merge inference defaults into configMap so they are persisted without losing unknown fields. + if modelConfig.Temperature != nil { + if _, exists := configMap["temperature"]; !exists { + configMap["temperature"] = *modelConfig.Temperature + } + } + if modelConfig.TopP != nil { + if _, exists := configMap["top_p"]; !exists { + configMap["top_p"] = *modelConfig.TopP + } + } + if modelConfig.TopK != nil { + if _, exists := configMap["top_k"]; !exists { + configMap["top_k"] = *modelConfig.TopK + } + } + if modelConfig.MinP != nil { + if _, exists := configMap["min_p"]; !exists { + configMap["min_p"] = *modelConfig.MinP + } + } + if modelConfig.RepeatPenalty != 0 { + if _, exists := configMap["repeat_penalty"]; !exists { + configMap["repeat_penalty"] = modelConfig.RepeatPenalty + } + } + if modelConfig.PresencePenalty != 0 { + if _, exists := configMap["presence_penalty"]; !exists { + configMap["presence_penalty"] = modelConfig.PresencePenalty + } + } + + // Re-marshal from configMap to preserve unknown fields + updatedConfigYAML, err = yaml.Marshal(configMap) + if err != nil { + return nil, fmt.Errorf("failed to marshal config with inference defaults: %v", err) + } + if valid, err := modelConfig.Validate(); !valid { return nil, fmt.Errorf("failed to validate updated config YAML: %v", err) } diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go index a08230510..adb4b989f 100644 --- a/core/http/endpoints/anthropic/messages.go +++ b/core/http/endpoints/anthropic/messages.go @@ -306,6 +306,35 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic if textContent != "" { contentBlocks 
= append([]schema.AnthropicContentBlock{{Type: "text", Text: textContent}}, contentBlocks...) } + } else if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && result != "" { + // Automatic tool parsing fallback: no tools in request but model emitted tool call markup + parsed := functions.ParseFunctionCall(result, cfg.FunctionsConfig) + if len(parsed) > 0 { + stopReason = "tool_use" + stripped := functions.StripToolCallMarkup(result) + if stripped != "" { + contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped}) + } + for i, fc := range parsed { + var inputArgs map[string]interface{} + if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil { + inputArgs = map[string]interface{}{"raw": fc.Arguments} + } + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = fmt.Sprintf("toolu_%s_%d", id, i) + } + contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{ + Type: "tool_use", + ID: toolCallID, + Name: fc.Name, + Input: inputArgs, + }) + } + } else { + stopReason = "end_turn" + contentBlocks = []schema.AnthropicContentBlock{{Type: "text", Text: result}} + } } else { stopReason = "end_turn" contentBlocks = []schema.AnthropicContentBlock{ @@ -522,6 +551,51 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq } } + // Automatic tool parsing fallback for streaming: when no tools were requested + // but the model emitted tool call markup, parse and emit as tool_use blocks. 
+ if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && accumulatedContent != "" && toolCallsEmitted == 0 { + parsed := functions.ParseFunctionCall(accumulatedContent, cfg.FunctionsConfig) + if len(parsed) > 0 { + // Close the text content block + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_stop", + Index: currentBlockIndex, + }) + currentBlockIndex++ + inToolCall = true + + for i, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = fmt.Sprintf("toolu_%s_%d", id, i) + } + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_start", + Index: currentBlockIndex, + ContentBlock: &schema.AnthropicContentBlock{ + Type: "tool_use", + ID: toolCallID, + Name: fc.Name, + }, + }) + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_delta", + Index: currentBlockIndex, + Delta: &schema.AnthropicStreamDelta{ + Type: "input_json_delta", + PartialJSON: fc.Arguments, + }, + }) + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_stop", + Index: currentBlockIndex, + }) + currentBlockIndex++ + toolCallsEmitted++ + } + } + } + // No MCP tools to execute, close stream if !inToolCall { sendAnthropicSSE(c, schema.AnthropicStreamEvent{ diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index ac8921d67..871084054 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -751,8 +751,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls) } } - // Collect content for MCP conversation history - if hasMCPToolsStream && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil { + // Collect content for MCP conversation history and automatic tool parsing fallback + if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && ev.Choices[0].Delta != 
nil && ev.Choices[0].Delta.Content != nil { if s, ok := ev.Choices[0].Delta.Content.(string); ok { collectedContent += s } else if sp, ok := ev.Choices[0].Delta.Content.(*string); ok && sp != nil { @@ -857,6 +857,43 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } } + // Automatic tool parsing fallback for streaming: when no tools were + // requested but the model emitted tool call markup, parse and emit them. + if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && collectedContent != "" && !toolsCalled { + parsed := functions.ParseFunctionCall(collectedContent, config.FunctionsConfig) + for i, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = id + } + toolCallMsg := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{{ + Index: i, + ID: toolCallID, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + }, + }}, + }, + Index: 0, + }}, + Object: "chat.completion.chunk", + } + respData, _ := json.Marshal(toolCallMsg) + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + c.Response().Flush() + toolsCalled = true + } + } + // No MCP tools to execute, send final stop message finishReason := FinishReasonStop if toolsCalled && len(input.Tools) > 0 { @@ -995,6 +1032,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) } + // Content-based tool call fallback: if no tool calls were found, + // try parsing the raw result — ParseFunctionCall handles detection internally. 
+ if len(funcResults) == 0 { + contentFuncResults := functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) + if len(contentFuncResults) > 0 { + funcResults = contentFuncResults + textContentToReturn = functions.StripToolCallMarkup(cbRawResult) + } + } + noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0 switch { @@ -1070,6 +1117,48 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } } + // Automatic tool parsing fallback: when no tools/functions were in the + // request but the model emitted tool call markup, parse and surface them. + if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && len(result) > 0 { + for i, choice := range result { + if choice.Message == nil || choice.Message.Content == nil { + continue + } + contentStr, ok := choice.Message.Content.(string) + if !ok || contentStr == "" { + continue + } + parsed := functions.ParseFunctionCall(contentStr, config.FunctionsConfig) + if len(parsed) == 0 { + continue + } + stripped := functions.StripToolCallMarkup(contentStr) + toolCallsReason := FinishReasonToolCalls + result[i].FinishReason = &toolCallsReason + if stripped != "" { + result[i].Message.Content = &stripped + } else { + result[i].Message.Content = nil + } + for _, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = id + } + result[i].Message.ToolCalls = append(result[i].Message.ToolCalls, + schema.ToolCall{ + ID: toolCallID, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + }, + }, + ) + } + } + } + // MCP server-side tool execution loop: // If we have MCP tools and the model returned tool_calls, execute MCP tools // and re-run inference with the results appended to the conversation. 
diff --git a/core/http/endpoints/openresponses/responses.go b/core/http/endpoints/openresponses/responses.go index 28eb6433a..37c59b568 100644 --- a/core/http/endpoints/openresponses/responses.go +++ b/core/http/endpoints/openresponses/responses.go @@ -1013,6 +1013,35 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)}, }) } + } else if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && result != "" { + // Automatic tool parsing fallback: no tools in request but model emitted tool call markup + parsed := functions.ParseFunctionCall(result, cfg.FunctionsConfig) + if len(parsed) > 0 { + stripped := functions.StripToolCallMarkup(result) + if stripped != "" { + allOutputItems = append(allOutputItems, schema.ORItemField{ + Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(stripped, resultLogprobs)}, + }) + } + for _, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = fmt.Sprintf("fc_%s", uuid.New().String()) + } + allOutputItems = append(allOutputItems, schema.ORItemField{ + Type: "function_call", ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Status: "completed", CallID: toolCallID, Name: fc.Name, Arguments: fc.Arguments, + }) + } + } else { + allOutputItems = append(allOutputItems, schema.ORItemField{ + Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)}, + }) + } } else { allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), @@ -1539,6 +1568,43 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i Content: 
[]schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)}, }) } + } else if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && cleanedResult != "" { + // Automatic tool parsing fallback: no tools in request but model emitted tool call markup + parsed := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) + if len(parsed) > 0 { + stripped := functions.StripToolCallMarkup(cleanedResult) + if stripped != "" { + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(stripped, resultLogprobs)}, + }) + } + for _, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = fmt.Sprintf("fc_%s", uuid.New().String()) + } + outputItems = append(outputItems, schema.ORItemField{ + Type: "function_call", + ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Status: "completed", + CallID: toolCallID, + Name: fc.Name, + Arguments: fc.Arguments, + }) + } + } else { + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)}, + }) + } } else { // Simple text response (include logprobs if available) messageItem := schema.ORItemField{ @@ -2514,6 +2580,15 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 result = finalCleanedResult + // Automatic tool parsing fallback for streaming: parse tool calls from accumulated text + var streamFallbackToolCalls []functions.FuncCallResults + if cfg.FunctionsConfig.AutomaticToolParsingFallback && result != "" { + streamFallbackToolCalls = functions.ParseFunctionCall(result, cfg.FunctionsConfig) + if len(streamFallbackToolCalls) > 0 { + result = 
functions.StripToolCallMarkup(result) + } + } + // Convert logprobs for streaming events mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs) @@ -2552,10 +2627,42 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 }) sequenceNumber++ + // Emit function_call items from automatic tool parsing fallback + for _, fc := range streamFallbackToolCalls { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = fmt.Sprintf("fc_%s", uuid.New().String()) + } + outputIndex++ + functionCallItem := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "completed", + CallID: toolCallID, + Name: fc.Name, + Arguments: fc.Arguments, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + collectedOutputItems = append(collectedOutputItems, *functionCallItem) + } + // Emit response.completed now := time.Now().Unix() - // Collect final output items (reasoning first, then message) + // Collect final output items (reasoning first, then messages, then tool calls) var finalOutputItems []schema.ORItemField // Add reasoning item if it exists if currentReasoningID != "" && finalReasoning != "" { @@ -2577,6 +2684,12 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 } else { finalOutputItems = append(finalOutputItems, *messageItem) } + // Add function_call items from fallback + for _, item := range collectedOutputItems { + if item.Type == "function_call" { + finalOutputItems = append(finalOutputItems, item) + } + } responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{ InputTokens: 
noToolTokenUsage.Prompt, OutputTokens: noToolTokenUsage.Completion, diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index f9e34940b..a853eb3d6 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -222,6 +222,9 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. if input.TopP != nil { config.TopP = input.TopP } + if input.MinP != nil { + config.MinP = input.MinP + } if input.Backend != "" { config.Backend = input.Backend diff --git a/core/schema/prediction.go b/core/schema/prediction.go index cf1eda841..913e13c1a 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -95,6 +95,7 @@ type PredictionOptions struct { // Common options between all the API calls, part of the OpenAI spec TopP *float64 `json:"top_p,omitempty" yaml:"top_p,omitempty"` TopK *int `json:"top_k,omitempty" yaml:"top_k,omitempty"` + MinP *float64 `json:"min_p,omitempty" yaml:"min_p,omitempty"` Temperature *float64 `json:"temperature,omitempty" yaml:"temperature,omitempty"` Maxtokens *int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"` Echo bool `json:"echo,omitempty" yaml:"echo,omitempty"` diff --git a/pkg/functions/iterative_parser.go b/pkg/functions/iterative_parser.go index 724603993..2d613789f 100644 --- a/pkg/functions/iterative_parser.go +++ b/pkg/functions/iterative_parser.go @@ -621,25 +621,39 @@ func (p *ChatMsgParser) TryConsumeXMLToolCalls(format *XMLToolCallFormat) (bool, // Handle Functionary format (JSON parameters inside XML tags) - use regex parser if format.KeyStart == "" && format.ToolStart == "= 0 { + p.pos += last + len(format.ToolEnd) + } return true, nil } // Handle JSON-like formats (Apriel-1.5, Xiaomi-MiMo) - use regex parser if format.ToolStart != "" && strings.Contains(format.ToolStart, "{\"name\"") { - results, err := parseJSONLikeXMLFormat(p.input[p.pos:], format) + sub := p.input[p.pos:] + results, err := parseJSONLikeXMLFormat(sub, format) if err 
!= nil || len(results) == 0 { return false, nil } for _, result := range results { p.AddToolCall(result.Name, "", result.Arguments) } + // Advance position past the last scope/tool end tag + endTag := format.ScopeEnd + if endTag == "" { + endTag = format.ToolEnd + } + if last := strings.LastIndex(sub, endTag); last >= 0 { + p.pos += last + len(endTag) + } return true, nil } diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 2076be3cd..58becb2e9 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -111,6 +111,11 @@ type FunctionsConfig struct { // If set, only this format will be tried (overrides XMLFormatPreset) XMLFormat *XMLToolCallFormat `yaml:"xml_format,omitempty" json:"xml_format,omitempty"` + // AutomaticToolParsingFallback enables automatic tool call parsing fallback: + // - Wraps raw string arguments as {"query": raw_string} when JSON parsing fails + // - Parses tool calls from response content even when no tools were in the request + AutomaticToolParsingFallback bool `yaml:"automatic_tool_parsing_fallback,omitempty" json:"automatic_tool_parsing_fallback,omitempty"` + // DisablePEGParser disables the PEG parser and falls back to the legacy iterative parser DisablePEGParser bool `yaml:"disable_peg_parser,omitempty" json:"disable_peg_parser,omitempty"` @@ -549,35 +554,64 @@ func getScopeOrToolStart(format *XMLToolCallFormat) string { return format.ToolStart } +// ParseResult holds tool calls and any non-tool-call content extracted by the parser. +type ParseResult struct { + ToolCalls []FuncCallResults + Content string +} + // tryParseXMLFromScopeStart finds the first occurrence of scopeStart (or format.ToolStart), -// splits the input there, and parses only the suffix as XML tool calls. Returns (toolCalls, true) -// if any tool calls were parsed, else (nil, false). This mimics llama.cpp's PEG order so that +// splits the input there, and parses only the suffix as XML tool calls. 
Returns (result, true) +// if any tool calls were parsed, else (empty, false). This mimics llama.cpp's PEG order so that // reasoning or content before the tool block does not cause "whitespace only before scope" to fail. -func tryParseXMLFromScopeStart(s string, format *XMLToolCallFormat, isPartial bool) ([]FuncCallResults, bool) { +func tryParseXMLFromScopeStart(s string, format *XMLToolCallFormat, isPartial bool) (ParseResult, bool) { if format == nil { - return nil, false + return ParseResult{}, false } scopeStart := getScopeOrToolStart(format) if scopeStart == "" { - return nil, false + return ParseResult{}, false } idx := strings.Index(s, scopeStart) if idx < 0 { - return nil, false + return ParseResult{}, false } toolCallsPart := s[idx:] parser := NewChatMsgParser(toolCallsPart, isPartial) success, err := parser.TryConsumeXMLToolCalls(format) if err != nil { if _, ok := err.(*ChatMsgPartialException); ok && isPartial { - return parser.ToolCalls(), len(parser.ToolCalls()) > 0 + tc := parser.ToolCalls() + if len(tc) > 0 { + return ParseResult{ToolCalls: tc, Content: buildContent(s[:idx], parser)}, true + } } - return nil, false + return ParseResult{}, false } if success && len(parser.ToolCalls()) > 0 { - return parser.ToolCalls(), true + return ParseResult{ + ToolCalls: parser.ToolCalls(), + Content: buildContent(s[:idx], parser), + }, true } - return nil, false + return ParseResult{}, false +} + +// buildContent assembles the non-tool-call content from the text before the tool +// block, any content tracked by the parser, and any unconsumed trailing text. 
+func buildContent(before string, parser *ChatMsgParser) string { + var parts []string + if b := strings.TrimSpace(before); b != "" { + parts = append(parts, b) + } + if pc := strings.TrimSpace(parser.Content()); pc != "" { + parts = append(parts, pc) + } + remaining := parser.Input()[parser.Pos():] + if t := strings.TrimSpace(remaining); t != "" { + parts = append(parts, t) + } + return strings.Join(parts, " ") } // ParseXMLIterative parses XML tool calls using the iterative parser @@ -587,15 +621,15 @@ func tryParseXMLFromScopeStart(s string, format *XMLToolCallFormat, isPartial bo func ParseXMLIterative(s string, format *XMLToolCallFormat, isPartial bool) ([]FuncCallResults, error) { // Try split-on-scope first so reasoning/content before tool block is skipped if format != nil { - if results, ok := tryParseXMLFromScopeStart(s, format, isPartial); ok { - return results, nil + if pr, ok := tryParseXMLFromScopeStart(s, format, isPartial); ok { + return pr.ToolCalls, nil } } else { formats := getAllXMLFormats() for _, fmtPreset := range formats { if fmtPreset.format != nil { - if results, ok := tryParseXMLFromScopeStart(s, fmtPreset.format, isPartial); ok { - return results, nil + if pr, ok := tryParseXMLFromScopeStart(s, fmtPreset.format, isPartial); ok { + return pr.ToolCalls, nil } } } @@ -885,8 +919,17 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC // Marshal arguments to JSON string (handles both object and string cases) var d []byte if argsStr, ok := args.(string); ok { - // Already a string, use it directly - d = []byte(argsStr) + // Check if the string is valid JSON; if not, auto-heal if enabled + var testJSON map[string]any + if json.Unmarshal([]byte(argsStr), &testJSON) == nil { + d = []byte(argsStr) + } else if functionConfig.AutomaticToolParsingFallback { + healed := map[string]string{"query": argsStr} + d, _ = json.Marshal(healed) + xlog.Debug("Automatic tool parsing fallback: wrapped raw string arguments", "raw", 
argsStr) + } else { + d = []byte(argsStr) + } } else { // Object, marshal to JSON d, _ = json.Marshal(args) diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go index ca7955c44..89c883f78 100644 --- a/pkg/functions/parse_test.go +++ b/pkg/functions/parse_test.go @@ -2536,4 +2536,61 @@ def hello(): }) }) }) + + Context("Automatic tool parsing fallback", func() { + It("wraps malformed string args as query when enabled", func() { + input := `{"name": "web_search", "arguments": "search for cats"}` + cfg := FunctionsConfig{AutomaticToolParsingFallback: true} + results := ParseFunctionCall(input, cfg) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("web_search")) + + var args map[string]string + err := json.Unmarshal([]byte(results[0].Arguments), &args) + Expect(err).NotTo(HaveOccurred()) + Expect(args["query"]).To(Equal("search for cats")) + }) + + It("preserves malformed string args as-is when disabled", func() { + input := `{"name": "web_search", "arguments": "search for cats"}` + cfg := FunctionsConfig{AutomaticToolParsingFallback: false} + results := ParseFunctionCall(input, cfg) + Expect(results).To(HaveLen(1)) + Expect(results[0].Arguments).To(Equal("search for cats")) + }) + + It("does not alter valid JSON string args", func() { + input := `{"name": "web_search", "arguments": "{\"query\": \"cats\"}"}` + cfg := FunctionsConfig{AutomaticToolParsingFallback: true} + results := ParseFunctionCall(input, cfg) + Expect(results).To(HaveLen(1)) + Expect(results[0].Arguments).To(Equal(`{"query": "cats"}`)) + }) + }) + + Context("StripToolCallMarkup", func() { + It("removes functionary-style function blocks and keeps surrounding text", func() { + input := `Text before {"q":"cats"} text after` + result := StripToolCallMarkup(input) + Expect(result).To(Equal("Text before text after")) + }) + + It("removes qwen3-coder-style tool_call blocks and keeps preceding text", func() { + input := `Here is my answer cats` + result := 
StripToolCallMarkup(input) + Expect(result).To(Equal("Here is my answer")) + }) + + It("returns empty string when content is only tool calls", func() { + input := `{"q":"cats"}` + result := StripToolCallMarkup(input) + Expect(result).To(Equal("")) + }) + + It("preserves text with no tool call markup", func() { + input := "Just a normal response with no tools" + result := StripToolCallMarkup(input) + Expect(result).To(Equal("Just a normal response with no tools")) + }) + }) }) diff --git a/pkg/functions/strip.go b/pkg/functions/strip.go new file mode 100644 index 000000000..c705b6032 --- /dev/null +++ b/pkg/functions/strip.go @@ -0,0 +1,18 @@ +package functions + +import "strings" + +// StripToolCallMarkup extracts the non-tool-call content from a string +// by reusing the iterative XML parser which already separates content +// from tool calls. Returns the remaining text, trimmed. +func StripToolCallMarkup(content string) string { + for _, fmtPreset := range getAllXMLFormats() { + if fmtPreset.format == nil { + continue + } + if pr, ok := tryParseXMLFromScopeStart(content, fmtPreset.format, false); ok && len(pr.ToolCalls) > 0 { + return strings.TrimSpace(pr.Content) + } + } + return strings.TrimSpace(content) +}