// Mirror of https://github.com/mudler/LocalAI.git
// (synced 2026-04-18 22:10:15 -04:00; 144 lines, 3.5 KiB, Go).
package main
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"unsafe"
|
|
|
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
)
|
|
|
|
type SAM3 struct {
|
|
base.SingleThread
|
|
}
|
|
|
|
// Entry points into the native SAM3 library.
//
// Every variable below is a plain Go function value that starts out nil; this
// file only calls them, it never assigns them.
// NOTE(review): presumably they are bound to the shared library's symbols
// during program start-up — confirm where the binding happens.
var (
	// CppLoadModel loads the model at modelPath using the given number of
	// threads. Load treats any non-zero result as a failure code.
	CppLoadModel func(modelPath string, threads int) int

	// CppEncodeImage processes the image file at imagePath. Detect treats any
	// non-zero result as a failure code.
	CppEncodeImage func(imagePath string) int

	// CppSegmentPVS runs point/box-prompted segmentation. points and boxes are
	// raw addresses of the first element of flat arrays (0 when absent);
	// nPointTriples and nBoxQuads are the number of 3-value and 4-value groups
	// in them. Detect reads the result as a detection count and treats a
	// negative value as failure.
	CppSegmentPVS func(points uintptr, nPointTriples int, boxes uintptr, nBoxQuads int, threshold float32) int

	// CppSegmentPCS runs text-prompted segmentation. Detect reads the result
	// as a detection count and treats a negative value as failure.
	CppSegmentPCS func(textPrompt string, threshold float32) int

	// CppGetNDetections reports a detection count. It is not called in this
	// file.
	CppGetNDetections func() int

	// Per-detection accessors, indexed by i. Detect copies their results into
	// the X, Y, Width, Height and Confidence fields of each detection.
	CppGetDetectionX, CppGetDetectionY func(i int) float32
	CppGetDetectionW, CppGetDetectionH func(i int) float32
	CppGetDetectionScore               func(i int) float32

	// CppGetDetectionMaskPNG is used by Detect in two steps: called with a
	// zero buf and bufSize it returns the mask size in bytes, and called again
	// with the address of a buffer of that size it fills the buffer.
	CppGetDetectionMaskPNG func(i int, buf uintptr, bufSize int) int

	// CppFreeResults is deferred by Detect once segmentation has succeeded.
	// NOTE(review): presumably it releases the native result storage read by
	// the accessors above — confirm against the C++ side.
	CppFreeResults func()
)
|
|
|
|
func (s *SAM3) Load(opts *pb.ModelOptions) error {
|
|
modelFile := opts.ModelFile
|
|
if modelFile == "" {
|
|
modelFile = opts.Model
|
|
}
|
|
|
|
var modelPath string
|
|
if filepath.IsAbs(modelFile) {
|
|
modelPath = modelFile
|
|
} else {
|
|
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
|
}
|
|
|
|
threads := int(opts.Threads)
|
|
if threads <= 0 {
|
|
threads = 4
|
|
}
|
|
|
|
ret := CppLoadModel(modelPath, threads)
|
|
if ret != 0 {
|
|
return fmt.Errorf("failed to load SAM3 model (error %d): %s", ret, modelPath)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SAM3) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
|
// Decode base64 image and write to temp file
|
|
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
|
if err != nil {
|
|
return pb.DetectResponse{}, fmt.Errorf("failed to decode image: %w", err)
|
|
}
|
|
|
|
tmpFile, err := os.CreateTemp("", "sam3-*.png")
|
|
if err != nil {
|
|
return pb.DetectResponse{}, fmt.Errorf("failed to create temp file: %w", err)
|
|
}
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
if _, err := tmpFile.Write(imgData); err != nil {
|
|
tmpFile.Close()
|
|
return pb.DetectResponse{}, fmt.Errorf("failed to write temp file: %w", err)
|
|
}
|
|
tmpFile.Close()
|
|
|
|
// Encode image
|
|
ret := CppEncodeImage(tmpFile.Name())
|
|
if ret != 0 {
|
|
return pb.DetectResponse{}, fmt.Errorf("failed to encode image (error %d)", ret)
|
|
}
|
|
|
|
threshold := opts.Threshold
|
|
if threshold <= 0 {
|
|
threshold = 0.5
|
|
}
|
|
|
|
// Determine segmentation mode
|
|
var nDetections int
|
|
if opts.Prompt != "" {
|
|
// Text-prompted segmentation (PCS mode, SAM 3 only)
|
|
nDetections = CppSegmentPCS(opts.Prompt, threshold)
|
|
} else {
|
|
// Point/box-prompted segmentation (PVS mode)
|
|
var pointsPtr uintptr
|
|
var boxesPtr uintptr
|
|
nPointTriples := len(opts.Points) / 3
|
|
nBoxQuads := len(opts.Boxes) / 4
|
|
|
|
if nPointTriples > 0 {
|
|
pointsPtr = uintptr(unsafe.Pointer(&opts.Points[0]))
|
|
}
|
|
if nBoxQuads > 0 {
|
|
boxesPtr = uintptr(unsafe.Pointer(&opts.Boxes[0]))
|
|
}
|
|
|
|
nDetections = CppSegmentPVS(pointsPtr, nPointTriples, boxesPtr, nBoxQuads, threshold)
|
|
}
|
|
|
|
if nDetections < 0 {
|
|
return pb.DetectResponse{}, fmt.Errorf("segmentation failed")
|
|
}
|
|
|
|
defer CppFreeResults()
|
|
|
|
// Build response
|
|
detections := make([]*pb.Detection, nDetections)
|
|
for i := 0; i < nDetections; i++ {
|
|
det := &pb.Detection{
|
|
X: CppGetDetectionX(i),
|
|
Y: CppGetDetectionY(i),
|
|
Width: CppGetDetectionW(i),
|
|
Height: CppGetDetectionH(i),
|
|
Confidence: CppGetDetectionScore(i),
|
|
ClassName: "segment",
|
|
}
|
|
|
|
// Get mask PNG
|
|
maskSize := CppGetDetectionMaskPNG(i, 0, 0)
|
|
if maskSize > 0 {
|
|
maskBuf := make([]byte, maskSize)
|
|
CppGetDetectionMaskPNG(i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
|
|
det.Mask = maskBuf
|
|
}
|
|
|
|
detections[i] = det
|
|
}
|
|
|
|
return pb.DetectResponse{
|
|
Detections: detections,
|
|
}, nil
|
|
}
|