mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-18 21:45:01 -04:00
reranker tests and top_n check fix #7212 Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
This commit is contained in:
@@ -32,10 +32,22 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
|
||||
|
||||
var requestTopN int32
|
||||
docs := int32(len(input.Documents))
|
||||
if input.TopN == nil { // omit top_n to get all
|
||||
requestTopN = docs
|
||||
} else {
|
||||
requestTopN = int32(*input.TopN)
|
||||
if requestTopN < 1 {
|
||||
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
|
||||
}
|
||||
if requestTopN > docs { // make it more obvious for backends
|
||||
requestTopN = docs
|
||||
}
|
||||
}
|
||||
request := &proto.RerankRequest{
|
||||
Query: input.Query,
|
||||
TopN: int32(input.TopN),
|
||||
TopN: requestTopN,
|
||||
Documents: input.Documents,
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ type JINARerankRequest struct {
|
||||
BasicModelRequest
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN int `json:"top_n"`
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
Backend string `json:"backend"`
|
||||
}
|
||||
|
||||
|
||||
@@ -286,45 +286,64 @@ var _ = Describe("E2E test", func() {
|
||||
Context("reranker", func() {
|
||||
It("correctly", func() {
|
||||
modelName := "jina-reranker-v1-base-en"
|
||||
|
||||
req := schema.JINARerankRequest{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: modelName,
|
||||
},
|
||||
Query: "Organic skincare products for sensitive skin",
|
||||
Documents: []string{
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
"Organic cotton baby clothes for sensitive skin",
|
||||
"Natural organic skincare range for sensitive skin",
|
||||
"Tech gadgets for smart homes: 2024 edition",
|
||||
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"All-natural pet food for dogs with allergies",
|
||||
"Yoga mats made from recycled materials",
|
||||
},
|
||||
TopN: 3,
|
||||
const query = "Organic skincare products for sensitive skin"
|
||||
var documents = []string{
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
"Organic cotton baby clothes for sensitive skin",
|
||||
"Natural organic skincare range for sensitive skin",
|
||||
"Tech gadgets for smart homes: 2024 edition",
|
||||
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"All-natural pet food for dogs with allergies",
|
||||
"Yoga mats made from recycled materials",
|
||||
}
|
||||
// Exceed len or requested results
|
||||
randomValue := int(GinkgoRandomSeed()) % (len(documents) + 1)
|
||||
requestResults := randomValue + 1 // at least 1 results
|
||||
// Cap expectResults by the length of documents
|
||||
expectResults := min(requestResults, len(documents))
|
||||
var maybeSkipTopN = &requestResults
|
||||
if requestResults >= len(documents) && int(GinkgoRandomSeed())%2 == 0 {
|
||||
maybeSkipTopN = nil
|
||||
}
|
||||
|
||||
serialized, err := json.Marshal(req)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(serialized).ToNot(BeNil())
|
||||
|
||||
rerankerEndpoint := apiEndpoint + "/rerank"
|
||||
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(resp).ToNot(BeNil())
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp, body := requestRerank(modelName, query, documents, maybeSkipTopN, apiEndpoint)
|
||||
Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp))
|
||||
|
||||
deserializedResponse := schema.JINARerankResponse{}
|
||||
err = json.Unmarshal(body, &deserializedResponse)
|
||||
err := json.Unmarshal(body, &deserializedResponse)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(deserializedResponse).ToNot(BeZero())
|
||||
Expect(deserializedResponse.Model).To(Equal(modelName))
|
||||
Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
|
||||
//Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
|
||||
Expect(len(deserializedResponse.Results)).To(Equal(expectResults))
|
||||
// Assert that relevance scores are in decreasing order
|
||||
for i := 1; i < len(deserializedResponse.Results); i++ {
|
||||
Expect(deserializedResponse.Results[i].RelevanceScore).To(
|
||||
BeNumerically("<=", deserializedResponse.Results[i-1].RelevanceScore),
|
||||
fmt.Sprintf("Result at index %d should have lower relevance score than previous result.", i),
|
||||
)
|
||||
}
|
||||
// Assert that each result's index points to the correct document
|
||||
for i, result := range deserializedResponse.Results {
|
||||
Expect(result.Index).To(
|
||||
And(
|
||||
BeNumerically(">=", 0),
|
||||
BeNumerically("<", len(documents)),
|
||||
),
|
||||
fmt.Sprintf("Result at position %d has index %d which should be within bounds [0, %d)", i, result.Index, len(documents)),
|
||||
)
|
||||
Expect(result.Document.Text).To(
|
||||
Equal(documents[result.Index]),
|
||||
fmt.Sprintf("Result at position %d (index %d) should have document text '%s', but got '%s'",
|
||||
i, result.Index, documents[result.Index], result.Document.Text),
|
||||
)
|
||||
}
|
||||
zeroOrNeg := int(GinkgoRandomSeed())%2 - 1 // Results in either -1 or 0
|
||||
resp, body = requestRerank(modelName, query, documents, &zeroOrNeg, apiEndpoint)
|
||||
Expect(resp.StatusCode).To(Equal(422), fmt.Sprintf("body: %s, response: %+v", body, resp))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -350,3 +369,26 @@ func downloadHttpFile(url string) (string, error) {
|
||||
|
||||
return tmpfile.Name(), nil
|
||||
}
|
||||
|
||||
func requestRerank(modelName, query string, documents []string, topN *int, apiEndpoint string) (*http.Response, []byte) {
|
||||
req := schema.JINARerankRequest{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: modelName,
|
||||
},
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
TopN: topN,
|
||||
}
|
||||
|
||||
serialized, err := json.Marshal(req)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(serialized).ToNot(BeNil())
|
||||
rerankerEndpoint := apiEndpoint + "/rerank"
|
||||
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(resp).ToNot(BeNil())
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return resp, body
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user