mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
110 lines
2.9 KiB
Go
110 lines
2.9 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
)
|
|
|
|
// CppLoadModel is the native entry point used by (*AceStepCpp).Load to load
// the LM, text-encoder, DiT and VAE models from the given paths. A return
// value of 0 means success; any other value is reported as an error code.
// It is nil until assigned elsewhere in the package (presumably when the
// native library is bound at startup — confirm against the package's main).
var CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int

// CppGenerateMusic is the native entry point used by
// (*AceStepCpp).SoundGeneration. It receives the generation parameters and
// the output destination dst. A return value of 0 means success; any other
// value is reported as an error code. Like CppLoadModel, it is nil until
// assigned elsewhere in the package.
var CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
|
|
|
type AceStepCpp struct {
|
|
base.SingleThread
|
|
}
|
|
|
|
func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
|
// ModelFile is the LM model path
|
|
lmModel := opts.ModelFile
|
|
|
|
// Get the base directory from ModelFile for resolving relative paths
|
|
baseDir := opts.ModelPath
|
|
|
|
var textEncoderModel, ditModel, vaeModel string
|
|
|
|
for _, oo := range opts.Options {
|
|
key, value, found := strings.Cut(oo, ":")
|
|
if !found {
|
|
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
|
continue
|
|
}
|
|
switch key {
|
|
case "text_encoder_model":
|
|
textEncoderModel = value
|
|
case "dit_model":
|
|
ditModel = value
|
|
case "vae_model":
|
|
vaeModel = value
|
|
default:
|
|
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
|
}
|
|
}
|
|
|
|
if textEncoderModel == "" {
|
|
return fmt.Errorf("text_encoder_model option is required")
|
|
}
|
|
if ditModel == "" {
|
|
return fmt.Errorf("dit_model option is required")
|
|
}
|
|
if vaeModel == "" {
|
|
return fmt.Errorf("vae_model option is required")
|
|
}
|
|
|
|
// Resolve relative paths to the base directory
|
|
// If the path doesn't start with "/" it's relative
|
|
if !filepath.IsAbs(textEncoderModel) {
|
|
textEncoderModel = filepath.Join(baseDir, textEncoderModel)
|
|
}
|
|
if !filepath.IsAbs(ditModel) {
|
|
ditModel = filepath.Join(baseDir, ditModel)
|
|
}
|
|
if !filepath.IsAbs(vaeModel) {
|
|
vaeModel = filepath.Join(baseDir, vaeModel)
|
|
}
|
|
|
|
// Also resolve the lmModel if it's relative
|
|
if !filepath.IsAbs(lmModel) {
|
|
lmModel = filepath.Join(baseDir, lmModel)
|
|
}
|
|
|
|
fmt.Fprintf(os.Stderr, "[acestep-cpp] Resolved paths:\n")
|
|
fmt.Fprintf(os.Stderr, " LM Model: %s\n", lmModel)
|
|
fmt.Fprintf(os.Stderr, " Text Encoder: %s\n", textEncoderModel)
|
|
fmt.Fprintf(os.Stderr, " DiT Model: %s\n", ditModel)
|
|
fmt.Fprintf(os.Stderr, " VAE Model: %s\n", vaeModel)
|
|
|
|
if ret := CppLoadModel(lmModel, textEncoderModel, ditModel, vaeModel); ret != 0 {
|
|
return fmt.Errorf("failed to load acestep models (error code: %d)", ret)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AceStepCpp) SoundGeneration(req *pb.SoundGenerationRequest) error {
|
|
caption := req.GetCaption()
|
|
if caption == "" {
|
|
caption = req.GetText()
|
|
}
|
|
lyrics := req.GetLyrics()
|
|
bpm := int(req.GetBpm())
|
|
keyscale := req.GetKeyscale()
|
|
timesignature := req.GetTimesignature()
|
|
duration := req.GetDuration()
|
|
temperature := req.GetTemperature()
|
|
instrumental := req.GetInstrumental()
|
|
seed := 42
|
|
threads := 4
|
|
|
|
if ret := CppGenerateMusic(caption, lyrics, bpm, keyscale, timesignature, duration, temperature, instrumental, seed, req.GetDst(), threads); ret != 0 {
|
|
return fmt.Errorf("failed to generate music (error code: %d)", ret)
|
|
}
|
|
|
|
return nil
|
|
}
|