mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-16 20:52:08 -04:00
fix(ollama): accept float-encoded integer options (num_ctx, top_k, ...) Home Assistant's Ollama integration encodes integer options as JSON floats (e.g. `"num_ctx": 8192.0`). Stdlib `json.Unmarshal` refuses to decode a number with fractional notation into an `int` field, so the entire request was rejected with HTTP 400 before reaching the backend: Unmarshal type error: expected=int, got=number 8192.0, field=options.num_ctx Add a custom `UnmarshalJSON` on `OllamaOptions` that routes the int fields (`top_k`, `num_predict`, `seed`, `repeat_last_n`, `num_ctx`) through `*json.Number`, then converts via `Int64()` with a `Float64()` fallback. Public field types are unchanged, so endpoint code is untouched. Float fields and `stop` continue to parse via the default path. Fixes #9837 Assisted-by: Claude Code:claude-opus-4-7 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -2,6 +2,8 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -18,6 +20,79 @@ type OllamaOptions struct {
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON accepts integer parameters encoded as either JSON ints
|
||||
// (`8192`) or JSON floats (`8192.0`). Some clients - notably Home Assistant's
|
||||
// Ollama integration - serialize ints as floats, which stdlib json refuses
|
||||
// to decode into int fields. See https://github.com/mudler/LocalAI/issues/9837.
|
||||
func (o *OllamaOptions) UnmarshalJSON(data []byte) error {
|
||||
type aux struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *json.Number `json:"top_k,omitempty"`
|
||||
NumPredict *json.Number `json:"num_predict,omitempty"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty,omitempty"`
|
||||
RepeatLastN *json.Number `json:"repeat_last_n,omitempty"`
|
||||
Seed *json.Number `json:"seed,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
NumCtx *json.Number `json:"num_ctx,omitempty"`
|
||||
}
|
||||
var a aux
|
||||
if err := json.Unmarshal(data, &a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
o.Temperature = a.Temperature
|
||||
o.TopP = a.TopP
|
||||
o.RepeatPenalty = a.RepeatPenalty
|
||||
o.Stop = a.Stop
|
||||
|
||||
var err error
|
||||
if o.TopK, err = jsonNumberToIntPtr(a.TopK); err != nil {
|
||||
return fmt.Errorf("options.top_k: %w", err)
|
||||
}
|
||||
if o.NumPredict, err = jsonNumberToIntPtr(a.NumPredict); err != nil {
|
||||
return fmt.Errorf("options.num_predict: %w", err)
|
||||
}
|
||||
if o.Seed, err = jsonNumberToIntPtr(a.Seed); err != nil {
|
||||
return fmt.Errorf("options.seed: %w", err)
|
||||
}
|
||||
if o.RepeatLastN, err = jsonNumberToInt(a.RepeatLastN); err != nil {
|
||||
return fmt.Errorf("options.repeat_last_n: %w", err)
|
||||
}
|
||||
if o.NumCtx, err = jsonNumberToInt(a.NumCtx); err != nil {
|
||||
return fmt.Errorf("options.num_ctx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// jsonNumberToInt parses a json.Number literal as an int, tolerating both
|
||||
// integer (`8192`) and float (`8192.0`) encodings. A nil pointer or empty
|
||||
// string yields 0, matching the zero-value semantics of the int fields.
|
||||
func jsonNumberToInt(n *json.Number) (int, error) {
|
||||
if n == nil || *n == "" {
|
||||
return 0, nil
|
||||
}
|
||||
if i, err := n.Int64(); err == nil {
|
||||
return int(i), nil
|
||||
}
|
||||
f, err := n.Float64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(f), nil
|
||||
}
|
||||
|
||||
func jsonNumberToIntPtr(n *json.Number) (*int, error) {
|
||||
if n == nil {
|
||||
return nil, nil
|
||||
}
|
||||
i, err := jsonNumberToInt(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &i, nil
|
||||
}
|
||||
|
||||
// OllamaMessage represents a message in Ollama chat format
|
||||
type OllamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
|
||||
@@ -84,3 +84,94 @@ var _ = Describe("OllamaEmbedRequest", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Several Ollama clients (notably Home Assistant's Python client) encode
// integer parameters as JSON floats (`8192.0`). Stdlib json refuses to
// unmarshal those into `int` fields, so OllamaOptions has a custom
// UnmarshalJSON that accepts both forms. See
// https://github.com/mudler/LocalAI/issues/9837.
var _ = Describe("OllamaOptions JSON unmarshaling", func() {
	// Baseline: plain integer literals must keep decoding exactly as
	// before the custom UnmarshalJSON was introduced.
	It("accepts integer literals for int fields", func() {
		body := []byte(`{"num_ctx": 8192, "num_predict": 256, "top_k": 40, "seed": 7, "repeat_last_n": 64}`)

		var opts OllamaOptions
		Expect(json.Unmarshal(body, &opts)).To(Succeed())

		Expect(opts.NumCtx).To(Equal(8192))
		Expect(opts.NumPredict).NotTo(BeNil())
		Expect(*opts.NumPredict).To(Equal(256))
		Expect(opts.TopK).NotTo(BeNil())
		Expect(*opts.TopK).To(Equal(40))
		Expect(opts.Seed).NotTo(BeNil())
		Expect(*opts.Seed).To(Equal(7))
		Expect(opts.RepeatLastN).To(Equal(64))
	})

	// The regression case: float-encoded integers must decode to the same
	// values as their integer-literal counterparts above.
	It("accepts float literals for int fields (Home Assistant Ollama client)", func() {
		body := []byte(`{"num_ctx": 8192.0, "num_predict": 256.0, "top_k": 40.0, "seed": 7.0, "repeat_last_n": 64.0}`)

		var opts OllamaOptions
		Expect(json.Unmarshal(body, &opts)).To(Succeed())

		Expect(opts.NumCtx).To(Equal(8192))
		Expect(opts.NumPredict).NotTo(BeNil())
		Expect(*opts.NumPredict).To(Equal(256))
		Expect(opts.TopK).NotTo(BeNil())
		Expect(*opts.TopK).To(Equal(40))
		Expect(opts.Seed).NotTo(BeNil())
		Expect(*opts.Seed).To(Equal(7))
		Expect(opts.RepeatLastN).To(Equal(64))
	})

	// Fields that bypass the json.Number conversion (floats and the stop
	// list) must still come through the default decoding path untouched.
	It("preserves float fields and stop list", func() {
		body := []byte(`{"temperature": 0.7, "top_p": 0.9, "repeat_penalty": 1.1, "stop": ["<|end|>", "</s>"]}`)

		var opts OllamaOptions
		Expect(json.Unmarshal(body, &opts)).To(Succeed())

		Expect(opts.Temperature).NotTo(BeNil())
		Expect(*opts.Temperature).To(Equal(0.7))
		Expect(opts.TopP).NotTo(BeNil())
		Expect(*opts.TopP).To(Equal(0.9))
		Expect(opts.RepeatPenalty).To(Equal(1.1))
		Expect(opts.Stop).To(Equal([]string{"<|end|>", "</s>"}))
	})

	// Absent pointer fields stay nil; absent plain-int fields keep their
	// zero value. Callers distinguish "unset" from "zero" via the pointers.
	It("leaves optional int fields nil when absent", func() {
		body := []byte(`{}`)

		var opts OllamaOptions
		Expect(json.Unmarshal(body, &opts)).To(Succeed())

		Expect(opts.NumPredict).To(BeNil())
		Expect(opts.TopK).To(BeNil())
		Expect(opts.Seed).To(BeNil())
		Expect(opts.NumCtx).To(Equal(0))
		Expect(opts.RepeatLastN).To(Equal(0))
	})

	// End-to-end shape of the bug report: options nested inside a full
	// chat request, with float-encoded num_ctx/top_k.
	It("accepts nested options on a chat request with float num_ctx", func() {
		// Mirrors the payload Home Assistant sends; reproduces issue #9837.
		body := []byte(`{
			"model": "qwen2",
			"messages": [{"role": "user", "content": "hi"}],
			"options": {"num_ctx": 8192.0, "top_k": 40.0}
		}`)

		var req OllamaChatRequest
		Expect(json.Unmarshal(body, &req)).To(Succeed())

		Expect(req.Options).NotTo(BeNil())
		Expect(req.Options.NumCtx).To(Equal(8192))
		Expect(req.Options.TopK).NotTo(BeNil())
		Expect(*req.Options.TopK).To(Equal(40))
	})

	// Garbage that is neither an int nor a float must still be rejected
	// rather than silently zeroed.
	It("rejects non-numeric values with a clear error", func() {
		body := []byte(`{"num_ctx": "not-a-number"}`)

		var opts OllamaOptions
		err := json.Unmarshal(body, &opts)
		Expect(err).To(HaveOccurred())
	})
})
|
||||
|
||||
Reference in New Issue
Block a user