mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 16:51:44 -04:00
feat(rfdetr): add object detection API (#5923)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
73ecb7f90b
commit
949e5b9be8
34
core/backend/detection.go
Normal file
34
core/backend/detection.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func Detection(
|
||||
sourceFile string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig,
|
||||
) (*proto.DetectResponse, error) {
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
detectionModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer loader.Close()
|
||||
|
||||
if detectionModel == nil {
|
||||
return nil, fmt.Errorf("could not load detection model")
|
||||
}
|
||||
|
||||
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
|
||||
Src: sourceFile,
|
||||
})
|
||||
|
||||
return res, err
|
||||
}
|
||||
@@ -458,6 +458,7 @@ const (
|
||||
FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
|
||||
FLAG_VAD BackendConfigUsecases = 0b010000000000
|
||||
FLAG_VIDEO BackendConfigUsecases = 0b100000000000
|
||||
FLAG_DETECTION BackendConfigUsecases = 0b1000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
@@ -479,6 +480,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -572,6 +574,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_DETECTION) == FLAG_DETECTION {
|
||||
if c.Backend != "rfdetr" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
|
||||
if c.Backend != "transformers-musicgen" {
|
||||
return false
|
||||
|
||||
59
core/http/endpoints/localai/detection.go
Normal file
59
core/http/endpoints/localai/detection.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
|
||||
// @Summary Detects objects in the input image.
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
|
||||
|
||||
image, err := utils.GetContentURIAsBase64(input.Image)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := backend.Detection(image, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response := schema.DetectionResponse{
|
||||
Detections: make([]schema.Detection, len(res.Detections)),
|
||||
}
|
||||
for i, detection := range res.Detections {
|
||||
response.Detections[i] = schema.Detection{
|
||||
X: detection.X,
|
||||
Y: detection.Y,
|
||||
Width: detection.Width,
|
||||
Height: detection.Height,
|
||||
ClassName: detection.ClassName,
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
}
|
||||
|
||||
router.Post("/v1/detection",
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
|
||||
localai.DetectionEndpoint(cl, ml, appConfig))
|
||||
|
||||
router.Post("/tts",
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
|
||||
@@ -90,6 +90,14 @@
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-headphones mr-2"></i>Whisper
|
||||
</button>
|
||||
<button hx-post="browse/search/backends"
|
||||
class="inline-flex items-center rounded-full px-4 py-2 text-sm font-medium bg-red-900/60 text-red-200 border border-red-700/50 hover:bg-red-800 transition duration-200 ease-in-out"
|
||||
hx-target="#search-results"
|
||||
hx-vals='{"search": "object-detection"}'
|
||||
onclick="hidePagination()"
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-eye mr-2"></i>Object detection
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -115,6 +115,14 @@
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-headphones mr-2"></i>Audio transcription
|
||||
</button>
|
||||
<button hx-post="browse/search/models"
|
||||
class="inline-flex items-center rounded-full px-4 py-2 text-sm font-medium bg-red-900/60 text-red-200 border border-red-700/50 hover:bg-red-800 transition duration-200 ease-in-out"
|
||||
hx-target="#search-results"
|
||||
hx-vals='{"search": "object-detection"}'
|
||||
onclick="hidePagination()"
|
||||
hx-indicator=".htmx-indicator">
|
||||
<i class="fas fa-eye mr-2"></i>Object detection
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -120,3 +120,20 @@ type SystemInformationResponse struct {
|
||||
Backends []string `json:"backends"`
|
||||
Models []SysInfoModel `json:"loaded_models"`
|
||||
}
|
||||
|
||||
// DetectionRequest is the request body accepted by the /v1/detection endpoint.
type DetectionRequest struct {
|
||||
// Embedded common request fields; the detection endpoint reads Model from here.
BasicModelRequest
|
||||
// Image is the input image reference; the detection endpoint passes it to
// utils.GetContentURIAsBase64 before calling the backend.
// NOTE(review): accepted forms (URL, data URI, ...) depend on that helper — confirm.
Image string `json:"image"`
|
||||
}
|
||||
|
||||
// DetectionResponse is the JSON body returned by the /v1/detection endpoint.
type DetectionResponse struct {
|
||||
// Detections holds one entry per detection reported by the backend.
Detections []Detection `json:"detections"`
|
||||
}
|
||||
|
||||
// Detection describes a single detected object: a box and its class name.
// The values are copied field-for-field from the backend's detection result.
// NOTE(review): units and origin of the box (pixels vs. normalized, corner vs.
// center) are not visible here — confirm against the backend.
type Detection struct {
|
||||
// X is the box's x coordinate as reported by the backend.
X float32 `json:"x"`
|
||||
// Y is the box's y coordinate as reported by the backend.
Y float32 `json:"y"`
|
||||
// Width is the box width as reported by the backend.
Width float32 `json:"width"`
|
||||
// Height is the box height as reported by the backend.
Height float32 `json:"height"`
|
||||
// ClassName is the label the backend assigned to the detected object.
ClassName string `json:"class_name"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user