diff --git a/core/schema/ollama.go b/core/schema/ollama.go
index 68deaf416..964b922db 100644
--- a/core/schema/ollama.go
+++ b/core/schema/ollama.go
@@ -120,10 +120,14 @@ type OllamaGenerateResponse struct {
 	EvalDuration int64 `json:"eval_duration,omitempty"`
 }
 
-// OllamaEmbedRequest represents a request to the Ollama Embed API
+// OllamaEmbedRequest represents a request to the Ollama Embed API.
+// Ollama's /api/embed endpoint accepts both `input` and `prompt` as the
+// input string value (see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings),
+// so both keys are deserialized here for client compatibility.
 type OllamaEmbedRequest struct {
-	Model   string        `json:"model"`
-	Input   any           `json:"input"` // string or []string
+	Model   string         `json:"model"`
+	Input   any            `json:"input,omitempty"`  // string or []string
+	Prompt  any            `json:"prompt,omitempty"` // string or []string (Ollama alias for Input)
 	Options *OllamaOptions `json:"options,omitempty"`
 }
 
@@ -135,10 +139,21 @@ func (r *OllamaEmbedRequest) ModelName(s *string) string {
 	return r.Model
 }
 
-// GetInputStrings normalizes the Input field to a string slice
+// GetInputStrings normalizes the Input/Prompt field to a string slice.
+// Input takes precedence over Prompt when both are provided.
 func (r *OllamaEmbedRequest) GetInputStrings() []string {
-	switch v := r.Input.(type) {
+	if v := normalizeOllamaEmbedInput(r.Input); v != nil {
+		return v
+	}
+	return normalizeOllamaEmbedInput(r.Prompt)
+}
+
+func normalizeOllamaEmbedInput(v any) []string {
+	switch v := v.(type) {
 	case string:
+		if v == "" {
+			return nil
+		}
 		return []string{v}
 	case []any:
 		var result []string
diff --git a/core/schema/ollama_test.go b/core/schema/ollama_test.go
new file mode 100644
index 000000000..900e6954f
--- /dev/null
+++ b/core/schema/ollama_test.go
@@ -0,0 +1,86 @@
+package schema_test
+
+import (
+	"encoding/json"
+
+	. "github.com/mudler/LocalAI/core/schema"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("OllamaEmbedRequest", func() {
+
+	Context("GetInputStrings", func() {
+		It("returns a single string when Input is a string", func() {
+			req := OllamaEmbedRequest{Input: "hello world"}
+
+			Expect(req.GetInputStrings()).To(Equal([]string{"hello world"}))
+		})
+
+		It("returns a list of strings when Input is a []string", func() {
+			req := OllamaEmbedRequest{Input: []string{"hello", "world"}}
+
+			Expect(req.GetInputStrings()).To(Equal([]string{"hello", "world"}))
+		})
+
+		It("returns a list of strings when Input is a []any (post JSON unmarshal)", func() {
+			req := OllamaEmbedRequest{Input: []any{"hello", "world"}}
+
+			Expect(req.GetInputStrings()).To(Equal([]string{"hello", "world"}))
+		})
+	})
+
+	Context("JSON unmarshaling (Ollama API compatibility)", func() {
+		It("accepts the 'input' field as a single string", func() {
+			body := []byte(`{"model": "m", "input": "why is the sky blue?"}`)
+
+			var req OllamaEmbedRequest
+			Expect(json.Unmarshal(body, &req)).To(Succeed())
+
+			Expect(req.Model).To(Equal("m"))
+			Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?"}))
+		})
+
+		It("accepts the 'input' field as an array of strings", func() {
+			body := []byte(`{"model": "m", "input": ["why is the sky blue?", "why is the grass green?"]}`)
+
+			var req OllamaEmbedRequest
+			Expect(json.Unmarshal(body, &req)).To(Succeed())
+
+			Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?", "why is the grass green?"}))
+		})
+
+		// Ollama's embedding endpoint accepts both `input` and `prompt` keys:
+		// https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
+		// LocalAI must accept `prompt` so client libraries using that key are not broken.
+		// See https://github.com/mudler/LocalAI/issues/9767.
+		It("accepts the 'prompt' field as a single string (Ollama compatibility)", func() {
+			body := []byte(`{"model": "m", "prompt": "why is the sky blue?"}`)
+
+			var req OllamaEmbedRequest
+			Expect(json.Unmarshal(body, &req)).To(Succeed())
+
+			Expect(req.Model).To(Equal("m"))
+			Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?"}))
+		})
+
+		It("accepts the 'prompt' field as an array of strings (Ollama compatibility)", func() {
+			body := []byte(`{"model": "m", "prompt": ["why is the sky blue?", "why is the grass green?"]}`)
+
+			var req OllamaEmbedRequest
+			Expect(json.Unmarshal(body, &req)).To(Succeed())
+
+			Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?", "why is the grass green?"}))
+		})
+
+		It("prefers 'input' when both 'input' and 'prompt' are provided", func() {
+			body := []byte(`{"model": "m", "input": "from input", "prompt": "from prompt"}`)
+
+			var req OllamaEmbedRequest
+			Expect(json.Unmarshal(body, &req)).To(Succeed())
+
+			Expect(req.GetInputStrings()).To(Equal([]string{"from input"}))
+		})
+	})
+})