feat(rfdetr): add object detection API (#5923)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-07-27 22:02:51 +02:00
committed by GitHub
parent 73ecb7f90b
commit 949e5b9be8
34 changed files with 884 additions and 7 deletions

34
core/backend/detection.go Normal file
View File

@@ -0,0 +1,34 @@
package backend
import (
"context"
"fmt"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
// Detection loads the model described by backendConfig through loader and
// asks it to detect objects in sourceFile.
//
// sourceFile is forwarded unchanged as the Src field of the detect request.
// The load error is returned as-is when loading fails; a nil model without an
// error is reported as "could not load detection model". Otherwise the
// response and error produced by the model's Detect call are returned
// directly. loader.Close is deferred once loading has succeeded.
func Detection(
	sourceFile string,
	loader *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	backendConfig config.BackendConfig,
) (*proto.DetectResponse, error) {
	detectionModel, err := loader.Load(ModelOptions(backendConfig, appConfig)...)
	if err != nil {
		return nil, err
	}
	// Registered before the nil check so Close also runs on that path.
	defer loader.Close()

	if detectionModel == nil {
		return nil, fmt.Errorf("could not load detection model")
	}

	request := &proto.DetectOptions{Src: sourceFile}
	return detectionModel.Detect(context.Background(), request)
}

View File

@@ -458,6 +458,7 @@ const (
FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
FLAG_VAD BackendConfigUsecases = 0b010000000000
FLAG_VIDEO BackendConfigUsecases = 0b100000000000
FLAG_DETECTION BackendConfigUsecases = 0b1000000000000
// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
@@ -479,6 +480,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
"FLAG_VAD": FLAG_VAD,
"FLAG_LLM": FLAG_LLM,
"FLAG_VIDEO": FLAG_VIDEO,
"FLAG_DETECTION": FLAG_DETECTION,
}
}
@@ -572,6 +574,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
}
}
if (u & FLAG_DETECTION) == FLAG_DETECTION {
if c.Backend != "rfdetr" {
return false
}
}
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
if c.Backend != "transformers-musicgen" {
return false

View File

@@ -0,0 +1,59 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
//
// The returned handler expects the parsed *schema.DetectionRequest and the
// resolved *config.BackendConfig to already be stored in the fiber context
// locals. It answers fiber.ErrBadRequest when either is missing, nil or of
// the wrong type, or when the request names no model. The request image is
// resolved with utils.GetContentURIAsBase64, passed to backend.Detection, and
// the backend detections are copied into a schema.DetectionResponse that is
// written back as JSON. Errors from image resolution or from the backend are
// returned to fiber unchanged.
// @Summary Detects objects in the input image.
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
		// A typed nil pointer still satisfies the type assertion, so guard
		// against it before reading input.Model (mirrors the cfg check below).
		if !ok || input == nil || input.Model == "" {
			return fiber.ErrBadRequest
		}

		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
		if !ok || cfg == nil {
			return fiber.ErrBadRequest
		}

		// Log the requested model name rather than a constant placeholder so
		// the debug line actually identifies the request.
		log.Debug().Str("image", input.Image).Str("model", input.Model).Str("backend", cfg.Backend).Msg("Detection")

		image, err := utils.GetContentURIAsBase64(input.Image)
		if err != nil {
			return err
		}

		res, err := backend.Detection(image, ml, appConfig, *cfg)
		if err != nil {
			return err
		}

		// Translate the backend (gRPC) detections into the public API schema.
		response := schema.DetectionResponse{
			Detections: make([]schema.Detection, len(res.Detections)),
		}
		for i, detection := range res.Detections {
			response.Detections[i] = schema.Detection{
				X:         detection.X,
				Y:         detection.Y,
				Width:     detection.Width,
				Height:    detection.Height,
				ClassName: detection.ClassName,
			}
		}

		return c.JSON(response)
	}
}

View File

@@ -41,6 +41,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
}
router.Post("/v1/detection",
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
localai.DetectionEndpoint(cl, ml, appConfig))
router.Post("/tts",
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),

View File

@@ -90,6 +90,14 @@
hx-indicator=".htmx-indicator">
<i class="fas fa-headphones mr-2"></i>Whisper
</button>
<button hx-post="browse/search/backends"
class="inline-flex items-center rounded-full px-4 py-2 text-sm font-medium bg-red-900/60 text-red-200 border border-red-700/50 hover:bg-red-800 transition duration-200 ease-in-out"
hx-target="#search-results"
hx-vals='{"search": "object-detection"}'
onclick="hidePagination()"
hx-indicator=".htmx-indicator">
<i class="fas fa-eye mr-2"></i>Object detection
</button>
</div>
</div>
</div>

View File

@@ -115,6 +115,14 @@
hx-indicator=".htmx-indicator">
<i class="fas fa-headphones mr-2"></i>Audio transcription
</button>
<button hx-post="browse/search/models"
class="inline-flex items-center rounded-full px-4 py-2 text-sm font-medium bg-red-900/60 text-red-200 border border-red-700/50 hover:bg-red-800 transition duration-200 ease-in-out"
hx-target="#search-results"
hx-vals='{"search": "object-detection"}'
onclick="hidePagination()"
hx-indicator=".htmx-indicator">
<i class="fas fa-eye mr-2"></i>Object detection
</button>
</div>
</div>

View File

@@ -120,3 +120,20 @@ type SystemInformationResponse struct {
Backends []string `json:"backends"`
Models []SysInfoModel `json:"loaded_models"`
}
// DetectionRequest is the JSON request body of the detection endpoint.
type DetectionRequest struct {
// Embedded shared request fields; presumably this is where the Model name
// read by the detection handler comes from — confirm against BasicModelRequest.
BasicModelRequest
// Image is the input image reference. The detection handler resolves it
// with utils.GetContentURIAsBase64 before handing it to the backend.
Image string `json:"image"`
}
// DetectionResponse is the JSON response body of the detection endpoint.
type DetectionResponse struct {
// Detections holds one entry per detection reported by the backend.
Detections []Detection `json:"detections"`
}
// Detection describes a single detected object as returned to API clients.
// NOTE(review): X/Y/Width/Height presumably describe the bounding box in
// image pixel coordinates — units and origin are defined by the backend; confirm.
type Detection struct {
X float32 `json:"x"`
Y float32 `json:"y"`
Width float32 `json:"width"`
Height float32 `json:"height"`
// ClassName is the class label reported by the backend for this detection.
ClassName string `json:"class_name"`
}