Files
LocalAI/backend/go/acestep-cpp/acestepcpp_test.go
Ettore Di Giacinto a738f8b0e4 feat(backends): add ace-step.cpp (#8965)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-12 18:56:26 +01:00

196 lines
5.0 KiB
Go

package main
import (
"context"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// testAddr is the address the backend under test is told to listen on and
// the address the gRPC test client connects to.
const testAddr = "localhost:50051"

// startupWait is the time allowed for the backend process to come up after
// it has been launched.
const startupWait = 5 * time.Second
// skipIfNoModel returns the model directory named by ACESTEP_MODEL_DIR.
//
// It skips the calling test when the variable is unset, or when any of the
// four GGUF files the tests load is missing from that directory. The four
// files are the LM, text encoder, DiT and VAE models.
//
// A stat failure other than "not exist" (e.g. permission denied) fails the
// test rather than being ignored. Ignoring it would let the test continue
// with an unusable model directory.
func skipIfNoModel(t *testing.T) string {
	t.Helper()
	modelDir := os.Getenv("ACESTEP_MODEL_DIR")
	if modelDir == "" {
		t.Skip("ACESTEP_MODEL_DIR not set, skipping test (set to directory with GGUF models)")
	}
	required := []struct{ label, file string }{
		{"LM model", "acestep-5Hz-lm-0.6B-Q8_0.gguf"},
		{"Text encoder model", "Qwen3-Embedding-0.6B-Q8_0.gguf"},
		{"DiT model", "acestep-v15-turbo-Q8_0.gguf"},
		{"VAE model", "vae-BF16.gguf"},
	}
	for _, r := range required {
		path := filepath.Join(modelDir, r.file)
		_, err := os.Stat(path)
		switch {
		case os.IsNotExist(err):
			t.Skipf("%s file not found in %s, skipping", r.label, modelDir)
		case err != nil:
			t.Fatalf("Checking %s file %s: %v", r.label, path, err)
		}
	}
	return modelDir
}
func startServer(t *testing.T) *exec.Cmd {
t.Helper()
binary := os.Getenv("ACESTEP_BINARY")
if binary == "" {
binary = "./acestep-cpp"
}
if _, err := os.Stat(binary); os.IsNotExist(err) {
t.Skipf("Backend binary not found at %s, skipping", binary)
}
cmd := exec.Command(binary, "--addr", testAddr)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
time.Sleep(startupWait)
return cmd
}
// stopServer kills a process started by startServer and reaps it.
//
// It is safe to call with a nil command or with one that was never started.
// Failures from Kill and Wait are deliberately discarded: this is best-effort
// teardown, and the process may already have exited on its own.
func stopServer(cmd *exec.Cmd) {
	if cmd == nil || cmd.Process == nil {
		return
	}
	_ = cmd.Process.Kill()
	_ = cmd.Wait()
}
// dialGRPC returns a plaintext gRPC client connection to testAddr.
//
// Send and receive message limits are raised to 50 MiB; gRPC's default
// receive limit is 4 MiB. Presumably the larger limit is needed for large
// audio payloads (TODO confirm).
//
// The connection is lazy: nothing is dialed until the first RPC. The test
// fails immediately if the client cannot be constructed.
func dialGRPC(t *testing.T) *grpc.ClientConn {
	t.Helper()
	// grpc.NewClient replaces the deprecated grpc.Dial. Like the
	// non-blocking Dial used previously, it performs no I/O here.
	conn, err := grpc.NewClient(testAddr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024),
			grpc.MaxCallSendMsgSize(50*1024*1024),
		),
	)
	if err != nil {
		t.Fatalf("Failed to dial gRPC: %v", err)
	}
	return conn
}
// TestServerHealth starts the backend without loading any model and checks
// that its Health RPC replies with the message "OK".
func TestServerHealth(t *testing.T) {
	cmd := startServer(t)
	defer stopServer(cmd)
	conn := dialGRPC(t)
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	// Bound the RPC so a wedged server fails this test promptly instead of
	// hanging until the global `go test -timeout` fires.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	resp, err := client.Health(ctx, &pb.HealthMessage{})
	if err != nil {
		t.Fatalf("Health check failed: %v", err)
	}
	if got := string(resp.Message); got != "OK" {
		t.Fatalf("Expected OK, got %q", got)
	}
}
// TestLoadModel checks that the backend accepts the full model set.
//
// The LM GGUF is passed as ModelFile. The text encoder, DiT and VAE GGUFs are
// passed through Options as "key:path" entries.
func TestLoadModel(t *testing.T) {
	dir := skipIfNoModel(t)
	model := func(name string) string { return filepath.Join(dir, name) }

	server := startServer(t)
	defer stopServer(server)

	conn := dialGRPC(t)
	defer conn.Close()

	opts := &pb.ModelOptions{
		ModelFile: model("acestep-5Hz-lm-0.6B-Q8_0.gguf"),
		Options: []string{
			"text_encoder_model:" + model("Qwen3-Embedding-0.6B-Q8_0.gguf"),
			"dit_model:" + model("acestep-v15-turbo-Q8_0.gguf"),
			"vae_model:" + model("vae-BF16.gguf"),
		},
	}
	result, err := pb.NewBackendClient(conn).LoadModel(context.Background(), opts)
	switch {
	case err != nil:
		t.Fatalf("LoadModel failed: %v", err)
	case !result.Success:
		t.Fatalf("LoadModel returned failure: %s", result.Message)
	}
}
// TestSoundGeneration is an end-to-end check of music generation.
//
// It loads the full model set and requests a clip with Duration 10
// (presumably seconds), 120 BPM, 4/4 time. The output is written to
// output.wav in a per-test temporary directory. The test then checks that a
// file of more than a header's worth of bytes was produced.
func TestSoundGeneration(t *testing.T) {
	modelDir := skipIfNoModel(t)
	// t.TempDir is removed automatically once the test and its deferred
	// calls (including the server shutdown below) have finished.
	outputFile := filepath.Join(t.TempDir(), "output.wav")

	cmd := startServer(t)
	defer stopServer(cmd)
	conn := dialGRPC(t)
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	// Load models: the LM is the primary model; the rest go in Options.
	loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
		ModelFile: filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf"),
		Options: []string{
			"text_encoder_model:" + filepath.Join(modelDir, "Qwen3-Embedding-0.6B-Q8_0.gguf"),
			"dit_model:" + filepath.Join(modelDir, "acestep-v15-turbo-Q8_0.gguf"),
			"vae_model:" + filepath.Join(modelDir, "vae-BF16.gguf"),
		},
	})
	if err != nil {
		t.Fatalf("LoadModel failed: %v", err)
	}
	if !loadResp.Success {
		t.Fatalf("LoadModel returned failure: %s", loadResp.Message)
	}

	// Generate music. The optional request fields are pointers, hence the
	// addressable locals.
	duration := float32(10.0)
	temperature := float32(0.85)
	bpm := int32(120)
	caption := "A cheerful electronic dance track"
	timesig := "4/4"
	_, err = client.SoundGeneration(context.Background(), &pb.SoundGenerationRequest{
		Text:          caption,
		Caption:       &caption,
		Dst:           outputFile,
		Duration:      &duration,
		Temperature:   &temperature,
		Bpm:           &bpm,
		Timesignature: &timesig,
	})
	if err != nil {
		t.Fatalf("SoundGeneration failed: %v", err)
	}

	// Verify the output file exists and has content.
	info, err := os.Stat(outputFile)
	switch {
	case os.IsNotExist(err):
		t.Fatal("Output audio file was not created")
	case err != nil:
		t.Fatalf("Failed to stat output file: %v", err)
	}
	t.Logf("Output file size: %d bytes", info.Size())
	// A WAV header alone is at least 44 bytes; real audio should be far larger.
	if info.Size() < 1000 {
		t.Errorf("Output file too small (%d bytes), expected real audio data", info.Size())
	}
}