mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-24 00:28:55 -04:00
Compare commits
1 Commits
master
...
fix/collec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0ec84b18c |
@@ -70,7 +70,7 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
|
||||
@@ -116,7 +116,7 @@ func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFun
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name"))
|
||||
entries, err := svc.ListCollectionEntriesForUser(userID, decodedParam(c, "name"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -139,7 +139,7 @@ func GetCollectionEntryContentEndpoint(app *application.Application) echo.Handle
|
||||
if err != nil {
|
||||
entry = entryParam
|
||||
}
|
||||
content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry)
|
||||
content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, decodedParam(c, "name"), entry)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -164,7 +164,7 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults)
|
||||
results, err := svc.SearchCollectionForUser(userID, decodedParam(c, "name"), payload.Query, payload.MaxResults)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -182,7 +182,7 @@ func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil {
|
||||
if err := svc.ResetCollectionForUser(userID, decodedParam(c, "name")); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -202,7 +202,7 @@ func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFun
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry)
|
||||
remaining, err := svc.DeleteCollectionEntryForUser(userID, decodedParam(c, "name"), payload.Entry)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -230,7 +230,7 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
if payload.UpdateInterval < 1 {
|
||||
payload.UpdateInterval = 60
|
||||
}
|
||||
if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
|
||||
if err := svc.AddCollectionSourceForUser(userID, decodedParam(c, "name"), payload.URL, payload.UpdateInterval); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFu
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil {
|
||||
if err := svc.RemoveCollectionSourceForUser(userID, decodedParam(c, "name"), payload.URL); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -267,7 +267,7 @@ func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.Handle
|
||||
if err != nil {
|
||||
entry = entryParam
|
||||
}
|
||||
fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry)
|
||||
fpath, err := svc.GetCollectionEntryFilePathForUser(userID, decodedParam(c, "name"), entry)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -282,7 +282,7 @@ func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFun
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name"))
|
||||
sources, err := svc.ListCollectionSourcesForUser(userID, decodedParam(c, "name"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
|
||||
49
core/http/endpoints/localai/agent_collections_param_test.go
Normal file
49
core/http/endpoints/localai/agent_collections_param_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Regression for #10443: agent/collection names carry a "legacy-api-key:"
|
||||
// prefix, so the ':' is percent-encoded as %3A in the request path. Echo routes
|
||||
// such paths via URL.RawPath and stores the path-param value still escaped, so
|
||||
// handlers must URL-decode it before looking the collection up in the store -
|
||||
// otherwise the lookup sees "legacy-api-key%3ALiteraryResearch" and 404s.
|
||||
var _ = Describe("decodedParam", func() {
|
||||
var e *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
})
|
||||
|
||||
// route runs a request through Echo's real router so the path param is
|
||||
// populated exactly as it would be in production, then returns the decoded
|
||||
// value the handler would observe.
|
||||
route := func(rawPath string) string {
|
||||
var got string
|
||||
e.GET("/api/agents/collections/:name/upload", func(c echo.Context) error {
|
||||
got = decodedParam(c, "name")
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, rawPath, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
return got
|
||||
}
|
||||
|
||||
It("decodes a percent-encoded colon in the collection name", func() {
|
||||
got := route("/api/agents/collections/legacy-api-key%3ALiteraryResearch/upload")
|
||||
Expect(got).To(Equal("legacy-api-key:LiteraryResearch"))
|
||||
})
|
||||
|
||||
It("leaves an unencoded name untouched", func() {
|
||||
got := route("/api/agents/collections/PlainCollection/upload")
|
||||
Expect(got).To(Equal("PlainCollection"))
|
||||
})
|
||||
})
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@@ -33,6 +34,22 @@ func getUserID(c echo.Context) string {
|
||||
return user.ID
|
||||
}
|
||||
|
||||
// decodedParam returns the named path parameter, URL-decoding it.
|
||||
//
|
||||
// Echo routes a request via URL.RawPath whenever the path contains
|
||||
// percent-encoded characters (e.g. %3A for ':'), and in that case stores the
|
||||
// matched path-param value raw/escaped. Agent and collection names carry a
|
||||
// "legacy-api-key:" prefix, so the ':' arrives as %3A and the raw param no
|
||||
// longer matches the stored name. Callers must unescape before lookups.
|
||||
// Falls back to the raw value if it isn't valid percent-encoding.
|
||||
func decodedParam(c echo.Context, name string) string {
|
||||
raw := c.Param(name)
|
||||
if decoded, err := url.PathUnescape(raw); err == nil {
|
||||
return decoded
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// isAdminUser returns true if the authenticated user has admin role.
|
||||
func isAdminUser(c echo.Context) bool {
|
||||
user := auth.GetUser(c)
|
||||
@@ -127,7 +144,7 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
|
||||
statuses := svc.ListAgentsForUser(userID)
|
||||
active, exists := statuses[name]
|
||||
@@ -142,7 +159,7 @@ func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
var cfg state.AgentConfig
|
||||
if err := c.Bind(&cfg); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
@@ -161,7 +178,7 @@ func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
if err := svc.DeleteAgentForUser(userID, name); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -173,7 +190,7 @@ func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
cfg := svc.GetAgentConfigForUser(userID, name)
|
||||
if cfg == nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
@@ -186,7 +203,7 @@ func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil {
|
||||
if err := svc.PauseAgentForUser(userID, decodedParam(c, "name")); err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -197,7 +214,7 @@ func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil {
|
||||
if err := svc.ResumeAgentForUser(userID, decodedParam(c, "name")); err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -208,7 +225,7 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
|
||||
history := svc.GetAgentStatusForUser(userID, name)
|
||||
if history == nil {
|
||||
@@ -241,7 +258,7 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
|
||||
history, err := svc.GetAgentObservablesForUser(userID, name)
|
||||
if err != nil {
|
||||
@@ -261,7 +278,7 @@ func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFun
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
if err := svc.ClearAgentObservablesForUser(userID, name); err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -273,7 +290,7 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
var payload struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
@@ -302,7 +319,7 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
|
||||
// Try local SSE manager first
|
||||
manager := svc.GetSSEManagerForUser(userID, name)
|
||||
@@ -334,7 +351,7 @@ func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
name := decodedParam(c, "name")
|
||||
data, err := svc.ExportAgentForUser(userID, name)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
|
||||
Reference in New Issue
Block a user