mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-16 20:52:08 -04:00
fix(ollama): accept prompt alias on /api/embed for Ollama parity (#9780)
Ollama's embedding endpoint accepts both `input` and `prompt` as the input string value (see ollama/ollama docs/api.md#generate-embeddings). LocalAI only accepted `input`, which broke client libraries that send the `prompt` form. Add `Prompt` to OllamaEmbedRequest and have GetInputStrings fall back to it when Input is unset. Input still wins when both are provided. Fixes #9767. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -120,10 +120,14 @@ type OllamaGenerateResponse struct {
|
||||
EvalDuration int64 `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaEmbedRequest represents a request to the Ollama Embed API
|
||||
// OllamaEmbedRequest represents a request to the Ollama Embed API.
|
||||
// Ollama's /api/embed endpoint accepts both `input` and `prompt` as the
|
||||
// input string value (see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings),
|
||||
// so both keys are deserialized here for client compatibility.
|
||||
type OllamaEmbedRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"` // string or []string
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input,omitempty"` // string or []string
|
||||
Prompt any `json:"prompt,omitempty"` // string or []string (Ollama alias for Input)
|
||||
Options *OllamaOptions `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
@@ -135,10 +139,21 @@ func (r *OllamaEmbedRequest) ModelName(s *string) string {
|
||||
return r.Model
|
||||
}
|
||||
|
||||
// GetInputStrings normalizes the Input field to a string slice
|
||||
// GetInputStrings normalizes the Input/Prompt field to a string slice.
|
||||
// Input takes precedence over Prompt when both are provided.
|
||||
func (r *OllamaEmbedRequest) GetInputStrings() []string {
|
||||
switch v := r.Input.(type) {
|
||||
if v := normalizeOllamaEmbedInput(r.Input); v != nil {
|
||||
return v
|
||||
}
|
||||
return normalizeOllamaEmbedInput(r.Prompt)
|
||||
}
|
||||
|
||||
func normalizeOllamaEmbedInput(v any) []string {
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{v}
|
||||
case []any:
|
||||
var result []string
|
||||
|
||||
86
core/schema/ollama_test.go
Normal file
86
core/schema/ollama_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package schema_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
. "github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("OllamaEmbedRequest", func() {
|
||||
|
||||
Context("GetInputStrings", func() {
|
||||
It("returns a single string when Input is a string", func() {
|
||||
req := OllamaEmbedRequest{Input: "hello world"}
|
||||
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"hello world"}))
|
||||
})
|
||||
|
||||
It("returns a list of strings when Input is a []string", func() {
|
||||
req := OllamaEmbedRequest{Input: []string{"hello", "world"}}
|
||||
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"hello", "world"}))
|
||||
})
|
||||
|
||||
It("returns a list of strings when Input is a []any (post JSON unmarshal)", func() {
|
||||
req := OllamaEmbedRequest{Input: []any{"hello", "world"}}
|
||||
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"hello", "world"}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("JSON unmarshaling (Ollama API compatibility)", func() {
|
||||
It("accepts the 'input' field as a single string", func() {
|
||||
body := []byte(`{"model": "m", "input": "why is the sky blue?"}`)
|
||||
|
||||
var req OllamaEmbedRequest
|
||||
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||
|
||||
Expect(req.Model).To(Equal("m"))
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?"}))
|
||||
})
|
||||
|
||||
It("accepts the 'input' field as an array of strings", func() {
|
||||
body := []byte(`{"model": "m", "input": ["why is the sky blue?", "why is the grass green?"]}`)
|
||||
|
||||
var req OllamaEmbedRequest
|
||||
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?", "why is the grass green?"}))
|
||||
})
|
||||
|
||||
// Ollama's embedding endpoint accepts both `input` and `prompt` keys:
|
||||
// https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
|
||||
// LocalAI must accept `prompt` so client libraries using that key are not broken.
|
||||
// See https://github.com/mudler/LocalAI/issues/9767.
|
||||
It("accepts the 'prompt' field as a single string (Ollama compatibility)", func() {
|
||||
body := []byte(`{"model": "m", "prompt": "why is the sky blue?"}`)
|
||||
|
||||
var req OllamaEmbedRequest
|
||||
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||
|
||||
Expect(req.Model).To(Equal("m"))
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?"}))
|
||||
})
|
||||
|
||||
It("accepts the 'prompt' field as an array of strings (Ollama compatibility)", func() {
|
||||
body := []byte(`{"model": "m", "prompt": ["why is the sky blue?", "why is the grass green?"]}`)
|
||||
|
||||
var req OllamaEmbedRequest
|
||||
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"why is the sky blue?", "why is the grass green?"}))
|
||||
})
|
||||
|
||||
It("prefers 'input' when both 'input' and 'prompt' are provided", func() {
|
||||
body := []byte(`{"model": "m", "input": "from input", "prompt": "from prompt"}`)
|
||||
|
||||
var req OllamaEmbedRequest
|
||||
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||
|
||||
Expect(req.GetInputStrings()).To(Equal([]string{"from input"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user