mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-05 04:02:45 -05:00
feat(metal): try to extend support to remaining backends (#8374)
* feat(metal): try to extend support to remaining backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * neutts doesn't work Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * split outetts out of transformers Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Remove torch pin to whisperx Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
5195062e12
commit
e7fc604dbc
@@ -24,7 +24,6 @@ XPU=os.environ.get("XPU", "0") == "1"
|
||||
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
|
||||
from scipy.io import wavfile
|
||||
import outetts
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
@@ -89,7 +88,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
self.CUDA = torch.cuda.is_available()
|
||||
self.OV=False
|
||||
self.OuteTTS=False
|
||||
self.DiaTTS=False
|
||||
self.SentenceTransformer = False
|
||||
|
||||
@@ -239,45 +237,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.processor = self.processor.to("cuda")
|
||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "OuteTTS":
|
||||
autoTokenizer = False
|
||||
options = request.Options
|
||||
MODELNAME = "OuteAI/OuteTTS-0.3-1B"
|
||||
TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
|
||||
VERSION = "0.3"
|
||||
SPEAKER = "en_male_1"
|
||||
for opt in options:
|
||||
if opt.startswith("tokenizer:"):
|
||||
TOKENIZER = opt.split(":")[1]
|
||||
break
|
||||
if opt.startswith("version:"):
|
||||
VERSION = opt.split(":")[1]
|
||||
break
|
||||
if opt.startswith("speaker:"):
|
||||
SPEAKER = opt.split(":")[1]
|
||||
break
|
||||
|
||||
if model_name != "":
|
||||
MODELNAME = model_name
|
||||
|
||||
# Configure the model
|
||||
model_config = outetts.HFModelConfig_v2(
|
||||
model_path=MODELNAME,
|
||||
tokenizer_path=TOKENIZER
|
||||
)
|
||||
# Initialize the interface
|
||||
self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
|
||||
self.OuteTTS = True
|
||||
|
||||
self.interface.print_default_speakers()
|
||||
if request.AudioPath:
|
||||
if os.path.isabs(request.AudioPath):
|
||||
self.AudioPath = request.AudioPath
|
||||
else:
|
||||
self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
|
||||
self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
|
||||
else:
|
||||
self.speaker = self.interface.load_default_speaker(name=SPEAKER)
|
||||
elif request.Type == "SentenceTransformer":
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
@@ -588,30 +547,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def CallOuteTTS(self, request, context):
|
||||
try:
|
||||
print("[OuteTTS] generating TTS", file=sys.stderr)
|
||||
gen_cfg = outetts.GenerationConfig(
|
||||
text="Speech synthesis is the artificial production of human speech.",
|
||||
temperature=self.options.get("temperature", 0.1),
|
||||
repetition_penalty=self.options.get("repetition_penalty", 1.1),
|
||||
max_length=self.max_tokens,
|
||||
speaker=self.speaker,
|
||||
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
|
||||
)
|
||||
output = self.interface.generate(config=gen_cfg)
|
||||
print("[OuteTTS] Generated TTS", file=sys.stderr)
|
||||
output.save(request.dst)
|
||||
print("[OuteTTS] TTS done", file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||
def TTS(self, request, context):
|
||||
if self.OuteTTS:
|
||||
return self.CallOuteTTS(request, context)
|
||||
|
||||
if self.DiaTTS:
|
||||
print("DiaTTS", file=sys.stderr)
|
||||
return self.CallDiaTTS(request, context)
|
||||
|
||||
@@ -4,6 +4,5 @@ numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.2
|
||||
protobuf==6.33.5
|
||||
@@ -4,6 +4,5 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.2
|
||||
protobuf==6.33.5
|
||||
@@ -4,6 +4,5 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.2
|
||||
protobuf==6.33.5
|
||||
@@ -5,7 +5,5 @@ transformers
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
outetts
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
protobuf==6.33.5
|
||||
@@ -5,6 +5,5 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
outetts
|
||||
sentence-transformers==5.2.2
|
||||
protobuf==6.33.5
|
||||
8
backend/python/transformers/requirements-mps.txt
Normal file
8
backend/python/transformers/requirements-mps.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
protobuf==6.33.5
|
||||
Reference in New Issue
Block a user