mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-04 15:04:27 -04:00
chore: add Dia to the model gallery, fix backend (#5998)
* fix: correctly call OuteTTS and DiaTTS Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(model gallery): add dia Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
326fda3223
commit
4733adb983
@@ -229,8 +229,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
elif request.Type == "DiaForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
print("DiaForConditionalGeneration", file=sys.stderr)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
||||
if self.CUDA:
|
||||
self.model = self.model.to("cuda")
|
||||
self.processor = self.processor.to("cuda")
|
||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "OuteTTS":
|
||||
autoTokenizer = False
|
||||
@@ -536,7 +541,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def DiaTTS(self, request, context):
|
||||
def CallDiaTTS(self, request, context):
|
||||
"""
|
||||
Generates dialogue audio using the Dia model.
|
||||
|
||||
@@ -581,7 +586,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def OuteTTS(self, request, context):
|
||||
def CallOuteTTS(self, request, context):
|
||||
try:
|
||||
print("[OuteTTS] generating TTS", file=sys.stderr)
|
||||
gen_cfg = outetts.GenerationConfig(
|
||||
@@ -603,10 +608,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||
def TTS(self, request, context):
|
||||
if self.OuteTTS:
|
||||
return self.OuteTTS(request, context)
|
||||
return self.CallOuteTTS(request, context)
|
||||
|
||||
if self.DiaTTS:
|
||||
return self.DiaTTS(request, context)
|
||||
print("DiaTTS", file=sys.stderr)
|
||||
return self.CallDiaTTS(request, context)
|
||||
|
||||
model_name = request.model
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user