diff --git a/core/schema/ollama.go b/core/schema/ollama.go index 964b922db..8ea414dde 100644 --- a/core/schema/ollama.go +++ b/core/schema/ollama.go @@ -2,6 +2,8 @@ package schema import ( "context" + "encoding/json" + "fmt" "time" ) @@ -18,6 +20,79 @@ type OllamaOptions struct { NumCtx int `json:"num_ctx,omitempty"` } +// UnmarshalJSON accepts integer parameters encoded as either JSON ints +// (`8192`) or JSON floats (`8192.0`). Some clients - notably Home Assistant's +// Ollama integration - serialize ints as floats, which stdlib json refuses +// to decode into int fields. See https://github.com/mudler/LocalAI/issues/9837. +func (o *OllamaOptions) UnmarshalJSON(data []byte) error { + type aux struct { + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *json.Number `json:"top_k,omitempty"` + NumPredict *json.Number `json:"num_predict,omitempty"` + RepeatPenalty float64 `json:"repeat_penalty,omitempty"` + RepeatLastN *json.Number `json:"repeat_last_n,omitempty"` + Seed *json.Number `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + NumCtx *json.Number `json:"num_ctx,omitempty"` + } + var a aux + if err := json.Unmarshal(data, &a); err != nil { + return err + } + + o.Temperature = a.Temperature + o.TopP = a.TopP + o.RepeatPenalty = a.RepeatPenalty + o.Stop = a.Stop + + var err error + if o.TopK, err = jsonNumberToIntPtr(a.TopK); err != nil { + return fmt.Errorf("options.top_k: %w", err) + } + if o.NumPredict, err = jsonNumberToIntPtr(a.NumPredict); err != nil { + return fmt.Errorf("options.num_predict: %w", err) + } + if o.Seed, err = jsonNumberToIntPtr(a.Seed); err != nil { + return fmt.Errorf("options.seed: %w", err) + } + if o.RepeatLastN, err = jsonNumberToInt(a.RepeatLastN); err != nil { + return fmt.Errorf("options.repeat_last_n: %w", err) + } + if o.NumCtx, err = jsonNumberToInt(a.NumCtx); err != nil { + return fmt.Errorf("options.num_ctx: %w", err) + } + return nil +} + +// jsonNumberToInt parses a json.Number literal as an int, tolerating both +// integer (`8192`) and float (`8192.0`) encodings. A nil pointer or empty +// string yields 0, matching the zero-value semantics of the int fields. +func jsonNumberToInt(n *json.Number) (int, error) { + if n == nil || *n == "" { + return 0, nil + } + if i, err := n.Int64(); err == nil { + return int(i), nil + } + f, err := n.Float64() + if err != nil { + return 0, err + } + return int(f), nil +} + +func jsonNumberToIntPtr(n *json.Number) (*int, error) { + if n == nil { + return nil, nil + } + i, err := jsonNumberToInt(n) + if err != nil { + return nil, err + } + return &i, nil +} + // OllamaMessage represents a message in Ollama chat format type OllamaMessage struct { Role string `json:"role"` diff --git a/core/schema/ollama_test.go b/core/schema/ollama_test.go index 900e6954f..eb50e214f 100644 --- a/core/schema/ollama_test.go +++ b/core/schema/ollama_test.go @@ -84,3 +84,94 @@ var _ = Describe("OllamaEmbedRequest", func() { }) }) }) + +// Several Ollama clients (notably Home Assistant's Python client) encode +// integer parameters as JSON floats (`8192.0`). Stdlib json refuses to +// unmarshal those into `int` fields, so OllamaOptions has a custom +// UnmarshalJSON that accepts both forms. See +// https://github.com/mudler/LocalAI/issues/9837. +var _ = Describe("OllamaOptions JSON unmarshaling", func() { + It("accepts integer literals for int fields", func() { + body := []byte(`{"num_ctx": 8192, "num_predict": 256, "top_k": 40, "seed": 7, "repeat_last_n": 64}`) + + var opts OllamaOptions + Expect(json.Unmarshal(body, &opts)).To(Succeed()) + + Expect(opts.NumCtx).To(Equal(8192)) + Expect(opts.NumPredict).NotTo(BeNil()) + Expect(*opts.NumPredict).To(Equal(256)) + Expect(opts.TopK).NotTo(BeNil()) + Expect(*opts.TopK).To(Equal(40)) + Expect(opts.Seed).NotTo(BeNil()) + Expect(*opts.Seed).To(Equal(7)) + Expect(opts.RepeatLastN).To(Equal(64)) + }) + + It("accepts float literals for int fields (Home Assistant Ollama client)", func() { + body := []byte(`{"num_ctx": 8192.0, "num_predict": 256.0, "top_k": 40.0, "seed": 7.0, "repeat_last_n": 64.0}`) + + var opts OllamaOptions + Expect(json.Unmarshal(body, &opts)).To(Succeed()) + + Expect(opts.NumCtx).To(Equal(8192)) + Expect(opts.NumPredict).NotTo(BeNil()) + Expect(*opts.NumPredict).To(Equal(256)) + Expect(opts.TopK).NotTo(BeNil()) + Expect(*opts.TopK).To(Equal(40)) + Expect(opts.Seed).NotTo(BeNil()) + Expect(*opts.Seed).To(Equal(7)) + Expect(opts.RepeatLastN).To(Equal(64)) + }) + + It("preserves float fields and stop list", func() { + body := []byte(`{"temperature": 0.7, "top_p": 0.9, "repeat_penalty": 1.1, "stop": ["<|end|>", ""]}`) + + var opts OllamaOptions + Expect(json.Unmarshal(body, &opts)).To(Succeed()) + + Expect(opts.Temperature).NotTo(BeNil()) + Expect(*opts.Temperature).To(Equal(0.7)) + Expect(opts.TopP).NotTo(BeNil()) + Expect(*opts.TopP).To(Equal(0.9)) + Expect(opts.RepeatPenalty).To(Equal(1.1)) + Expect(opts.Stop).To(Equal([]string{"<|end|>", ""})) + }) + + It("leaves optional int fields nil when absent", func() { + body := []byte(`{}`) + + var opts OllamaOptions + Expect(json.Unmarshal(body, &opts)).To(Succeed()) + + Expect(opts.NumPredict).To(BeNil()) + Expect(opts.TopK).To(BeNil()) + Expect(opts.Seed).To(BeNil()) + Expect(opts.NumCtx).To(Equal(0)) + Expect(opts.RepeatLastN).To(Equal(0)) + }) + + It("accepts nested options on a chat request with float num_ctx", func() { + // Mirrors the payload Home Assistant sends; reproduces issue #9837. + body := []byte(`{ + "model": "qwen2", + "messages": [{"role": "user", "content": "hi"}], + "options": {"num_ctx": 8192.0, "top_k": 40.0} + }`) + + var req OllamaChatRequest + Expect(json.Unmarshal(body, &req)).To(Succeed()) + + Expect(req.Options).NotTo(BeNil()) + Expect(req.Options.NumCtx).To(Equal(8192)) + Expect(req.Options.TopK).NotTo(BeNil()) + Expect(*req.Options.TopK).To(Equal(40)) + }) + + It("rejects non-numeric values with a clear error", func() { + body := []byte(`{"num_ctx": "not-a-number"}`) + + var opts OllamaOptions + err := json.Unmarshal(body, &opts) + Expect(err).To(HaveOccurred()) + }) +})