Files
LocalAI/backend/go/sam3-cpp/gosam3.go
Ettore Di Giacinto 706cf5d43c feat(sam.cpp): add sam.cpp detection backend (#9288)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-04-09 21:49:11 +02:00

144 lines
3.5 KiB
Go

package main
import (
	"encoding/base64"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"unsafe"

	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// SAM3 is the backend type whose methods (Load, Detect) call the Cpp*
// function variables declared in this file.
//
// It carries no state of its own and only embeds base.SingleThread.
// NOTE(review): presumably the embedded type makes the backend handle one
// request at a time — confirm against pkg/grpc/base.
type SAM3 struct {
	base.SingleThread
}
// Function variables through which this file calls into the native code.
// They are declared without initializers, so each one is nil until something
// assigns it. NOTE(review): presumably they are bound to the native library's
// symbols at startup — the binding code is not visible in this file.
var (
	// Model and image setup. Load and Detect treat a non-zero return as an
	// error code.
	CppLoadModel   func(modelPath string, threads int) int
	CppEncodeImage func(imagePath string) int

	// Segmentation entry points. Detect treats a negative return as failure
	// and any other value as the number of detections.
	//
	// CppSegmentPCS is the text-prompted mode. CppSegmentPVS is the
	// point/box-prompted mode: Detect passes the address of the first element
	// of a flat float slice (or 0 when there are none) together with the
	// number of 3-value point groups and 4-value box groups.
	CppSegmentPCS func(textPrompt string, threshold float32) int
	CppSegmentPVS func(points uintptr, nPointTriples int, boxes uintptr, nBoxQuads int, threshold float32) int

	// Per-detection accessors, indexed by detection number i.
	CppGetNDetections    func() int
	CppGetDetectionX     func(i int) float32
	CppGetDetectionY     func(i int) float32
	CppGetDetectionW     func(i int) float32
	CppGetDetectionH     func(i int) float32
	CppGetDetectionScore func(i int) float32

	// CppGetDetectionMaskPNG is called twice by Detect: first with buf == 0
	// and bufSize == 0 to obtain the mask size, then with a buffer of that
	// size to receive the bytes.
	CppGetDetectionMaskPNG func(i int, buf uintptr, bufSize int) int

	// CppFreeResults is deferred by Detect once segmentation has succeeded.
	CppFreeResults func()
)
// Load resolves the model location from opts and hands it to CppLoadModel.
//
// The file name is opts.ModelFile, or opts.Model when ModelFile is empty. An
// absolute name is used as-is; anything else is joined onto opts.ModelPath.
// A non-positive opts.Threads falls back to 4 threads.
//
// It returns an error carrying the native return code and the resolved path
// when CppLoadModel returns non-zero.
func (s *SAM3) Load(opts *pb.ModelOptions) error {
	name := opts.ModelFile
	if name == "" {
		name = opts.Model
	}

	// Start from the name itself and only prefix the model directory when the
	// name is relative.
	fullPath := name
	if !filepath.IsAbs(name) {
		fullPath = filepath.Join(opts.ModelPath, name)
	}

	nThreads := 4
	if requested := int(opts.Threads); requested > 0 {
		nThreads = requested
	}

	if rc := CppLoadModel(fullPath, nThreads); rc != 0 {
		return fmt.Errorf("failed to load SAM3 model (error %d): %s", rc, fullPath)
	}
	return nil
}
// Detect runs segmentation on a base64-encoded image and returns one
// detection per segment reported by the native code.
//
// opts.Src is decoded with standard base64 and written to a temporary file,
// whose path is passed to CppEncodeImage; the file is removed when Detect
// returns. A non-positive opts.Threshold falls back to 0.5.
//
// When opts.Prompt is non-empty the text-prompted path (PCS) is used.
// Otherwise the point/box-prompted path (PVS) is used: opts.Points is read in
// groups of three values and opts.Boxes in groups of four, and trailing values
// that do not fill a group are not counted.
//
// Every detection carries its box, score, the fixed class name "segment" and,
// when the native side reports a positive mask size, the mask PNG bytes.
//
// An error is returned when the image cannot be decoded or written, when
// CppEncodeImage returns non-zero, or when segmentation returns a negative
// value.
func (s *SAM3) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
	// Decode base64 image and write to temp file
	imgData, err := base64.StdEncoding.DecodeString(opts.Src)
	if err != nil {
		return pb.DetectResponse{}, fmt.Errorf("failed to decode image: %w", err)
	}

	tmpFile, err := os.CreateTemp("", "sam3-*.png")
	if err != nil {
		return pb.DetectResponse{}, fmt.Errorf("failed to create temp file: %w", err)
	}
	// Best-effort cleanup; a failed removal must not mask the real result.
	defer os.Remove(tmpFile.Name())

	// Always close, then report the write error first. A Close error is also
	// fatal: the native encoder reads the file by path, so it must be complete
	// on disk before it is handed over.
	_, writeErr := tmpFile.Write(imgData)
	closeErr := tmpFile.Close()
	if writeErr != nil {
		return pb.DetectResponse{}, fmt.Errorf("failed to write temp file: %w", writeErr)
	}
	if closeErr != nil {
		return pb.DetectResponse{}, fmt.Errorf("failed to close temp file: %w", closeErr)
	}

	// Encode image
	if ret := CppEncodeImage(tmpFile.Name()); ret != 0 {
		return pb.DetectResponse{}, fmt.Errorf("failed to encode image (error %d)", ret)
	}

	threshold := opts.Threshold
	if threshold <= 0 {
		threshold = 0.5
	}

	// Determine segmentation mode
	var nDetections int
	if opts.Prompt != "" {
		// Text-prompted segmentation (PCS mode, SAM 3 only)
		nDetections = CppSegmentPCS(opts.Prompt, threshold)
	} else {
		// Point/box-prompted segmentation (PVS mode)
		var pointsPtr uintptr
		var boxesPtr uintptr
		nPointTriples := len(opts.Points) / 3
		nBoxQuads := len(opts.Boxes) / 4
		if nPointTriples > 0 {
			pointsPtr = uintptr(unsafe.Pointer(&opts.Points[0]))
		}
		if nBoxQuads > 0 {
			boxesPtr = uintptr(unsafe.Pointer(&opts.Boxes[0]))
		}
		nDetections = CppSegmentPVS(pointsPtr, nPointTriples, boxesPtr, nBoxQuads, threshold)
		// The addresses above are plain integers and do not keep the slices
		// reachable on their own; pin opts (and with it Points/Boxes) until
		// the native call has returned.
		runtime.KeepAlive(opts)
	}
	if nDetections < 0 {
		return pb.DetectResponse{}, fmt.Errorf("segmentation failed (error %d)", nDetections)
	}
	defer CppFreeResults()

	// Build response
	detections := make([]*pb.Detection, nDetections)
	for i := 0; i < nDetections; i++ {
		det := &pb.Detection{
			X:          CppGetDetectionX(i),
			Y:          CppGetDetectionY(i),
			Width:      CppGetDetectionW(i),
			Height:     CppGetDetectionH(i),
			Confidence: CppGetDetectionScore(i),
			ClassName:  "segment",
		}

		// Get mask PNG: the first call (no buffer) reports the size, the
		// second fills a buffer of exactly that size.
		maskSize := CppGetDetectionMaskPNG(i, 0, 0)
		if maskSize > 0 {
			maskBuf := make([]byte, maskSize)
			CppGetDetectionMaskPNG(i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
			det.Mask = maskBuf
		}
		detections[i] = det
	}

	return pb.DetectResponse{
		Detections: detections,
	}, nil
}