mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
fix: use exact tag matching for model gallery tag filtering (#9041)
The Search() method uses strings.Contains() on comma-joined tags, causing substring false positives (e.g., "asr" matching "image-diffusers"). Add FilterByTag() method that checks each tag with strings.EqualFold() for exact, case-insensitive matching. Add 'tag' query parameter to /api/models and /api/backends endpoints. Update the React frontend to send filter selections as 'tag' instead of 'term'. Closes #8775 Signed-off-by: majiayu000 <1835304752@qq.com>
This commit is contained in:
@@ -92,6 +92,19 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
|
||||
return filteredModels
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
|
||||
var filtered GalleryElements[T]
|
||||
for _, m := range gm {
|
||||
for _, t := range m.GetTags() {
|
||||
if strings.EqualFold(t, tag) {
|
||||
filtered = append(filtered, m)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
|
||||
sort.Slice(gm, func(i, j int) bool {
|
||||
if sortOrder == "asc" {
|
||||
|
||||
@@ -159,6 +159,68 @@ var _ = Describe("Gallery", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GalleryElements FilterByTag", func() {
|
||||
var elements GalleryElements[*GalleryModel]
|
||||
|
||||
BeforeEach(func() {
|
||||
elements = GalleryElements[*GalleryModel]{
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "whisper-asr",
|
||||
Tags: []string{"asr", "stt"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "image-diffusers",
|
||||
Tags: []string{"sd", "image"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "another-stt-model",
|
||||
Tags: []string{"stt", "audio"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Metadata: Metadata{
|
||||
Name: "no-tags-model",
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
It("should return exact tag matches only", func() {
|
||||
results := elements.FilterByTag("asr")
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].GetName()).To(Equal("whisper-asr"))
|
||||
})
|
||||
|
||||
It("should not match substrings (image-diffusers must NOT match 'asr')", func() {
|
||||
results := elements.FilterByTag("asr")
|
||||
for _, r := range results {
|
||||
Expect(r.GetName()).NotTo(Equal("image-diffusers"))
|
||||
}
|
||||
})
|
||||
|
||||
It("should be case insensitive", func() {
|
||||
results := elements.FilterByTag("ASR")
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].GetName()).To(Equal("whisper-asr"))
|
||||
})
|
||||
|
||||
It("should return multiple models with the same tag", func() {
|
||||
results := elements.FilterByTag("stt")
|
||||
Expect(results).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("should return empty when no models have the tag", func() {
|
||||
results := elements.FilterByTag("nonexistent")
|
||||
Expect(results).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GalleryElements SortByName", func() {
|
||||
var elements GalleryElements[*GalleryModel]
|
||||
|
||||
|
||||
@@ -130,13 +130,12 @@ export default function Models() {
|
||||
const filterVal = params.filter !== undefined ? params.filter : filter
|
||||
const sortVal = params.sort !== undefined ? params.sort : sort
|
||||
const backendVal = params.backendFilter !== undefined ? params.backendFilter : backendFilter
|
||||
// Combine search text and filter into 'term' param
|
||||
const term = searchVal || filterVal || ''
|
||||
const queryParams = {
|
||||
page: params.page || page,
|
||||
items: 9,
|
||||
}
|
||||
if (term) queryParams.term = term
|
||||
if (filterVal) queryParams.tag = filterVal
|
||||
if (searchVal) queryParams.term = searchVal
|
||||
if (backendVal) queryParams.backend = backendVal
|
||||
if (sortVal) {
|
||||
queryParams.sort = sortVal
|
||||
|
||||
@@ -210,6 +210,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
// Model Gallery APIs (admin only)
|
||||
app.GET("/api/models", func(c echo.Context) error {
|
||||
term := c.QueryParam("term")
|
||||
tag := c.QueryParam("tag")
|
||||
page := c.QueryParam("page")
|
||||
if page == "" {
|
||||
page = "1"
|
||||
@@ -253,6 +254,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
}
|
||||
sort.Strings(backendNames)
|
||||
|
||||
if tag != "" {
|
||||
models = gallery.GalleryElements[*gallery.GalleryModel](models).FilterByTag(tag)
|
||||
}
|
||||
if term != "" {
|
||||
models = gallery.GalleryElements[*gallery.GalleryModel](models).Search(term)
|
||||
}
|
||||
@@ -776,6 +780,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
// Backend Gallery APIs
|
||||
app.GET("/api/backends", func(c echo.Context) error {
|
||||
term := c.QueryParam("term")
|
||||
tag := c.QueryParam("tag")
|
||||
page := c.QueryParam("page")
|
||||
if page == "" {
|
||||
page = "1"
|
||||
@@ -806,6 +811,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
}
|
||||
sort.Strings(tags)
|
||||
|
||||
if tag != "" {
|
||||
backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).FilterByTag(tag)
|
||||
}
|
||||
if term != "" {
|
||||
backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).Search(term)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user