mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-04 23:14:41 -04:00
feat(transformers): add support to Dia (#5991)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
09457b9221
commit
003b9292fe
@@ -22,7 +22,7 @@ import torch.cuda
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
|
||||
from scipy.io import wavfile
|
||||
import outetts
|
||||
from sentence_transformers import SentenceTransformer
|
||||
@@ -90,6 +90,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.CUDA = torch.cuda.is_available()
|
||||
self.OV=False
|
||||
self.OuteTTS=False
|
||||
self.DiaTTS=False
|
||||
self.SentenceTransformer = False
|
||||
|
||||
device_map="cpu"
|
||||
@@ -97,6 +98,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
quantization = None
|
||||
autoTokenizer = True
|
||||
|
||||
# Parse options from request.Options
|
||||
self.options = {}
|
||||
options = request.Options
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when generating
|
||||
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
# if value is a number, convert it to the appropriate type
|
||||
try:
|
||||
if "." in value:
|
||||
value = float(value)
|
||||
else:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
# Keep as string if conversion fails
|
||||
pass
|
||||
self.options[key] = value
|
||||
|
||||
print(f"Parsed options: {self.options}", file=sys.stderr)
|
||||
|
||||
if self.CUDA:
|
||||
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
|
||||
if request.MainGPU:
|
||||
@@ -202,6 +227,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
autoTokenizer = False
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
elif request.Type == "DiaForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "OuteTTS":
|
||||
autoTokenizer = False
|
||||
options = request.Options
|
||||
@@ -262,7 +292,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
|
||||
self.max_tokens = self.model.config.max_position_embeddings
|
||||
else:
|
||||
self.max_tokens = 512
|
||||
self.max_tokens = self.options.get("max_new_tokens", 512)
|
||||
|
||||
if autoTokenizer:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
|
||||
@@ -485,16 +515,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
tokens = 256
|
||||
if request.HasField('duration'):
|
||||
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
||||
guidance = 3.0
|
||||
guidance = self.options.get("guidance_scale", 3.0)
|
||||
if request.HasField('temperature'):
|
||||
guidance = request.temperature
|
||||
dosample = True
|
||||
dosample = self.options.get("do_sample", True)
|
||||
if request.HasField('sample'):
|
||||
dosample = request.sample
|
||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
|
||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
|
||||
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
@@ -506,13 +535,59 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def DiaTTS(self, request, context):
    """
    Synthesize dialogue audio with the loaded Dia model and save it to disk.

    The request text is passed through ``self.processor``, fed to
    ``self.model.generate`` together with the sampling settings, decoded
    with ``self.processor.batch_decode`` and written to ``request.dst``
    via ``self.processor.save_audio``.

    Sampling settings come from ``self.options`` with these fallbacks:
    guidance_scale=3.0, temperature=1.8, top_p=0.90, top_k=45.
    ``max_new_tokens`` is always taken from ``self.max_tokens``.

    Args:
        request: Request object; only ``request.text`` and ``request.dst``
            are read here.
        context: The gRPC context (not used by this method).

    Returns:
        ``backend_pb2.Result(success=True)`` when the audio was written,
        otherwise ``backend_pb2.Result(success=False, message=...)``
        describing the exception that was caught.

    NOTE(review): elsewhere in this class ``self.DiaTTS`` is also assigned
    a boolean flag, which would shadow this method on the instance —
    confirm that the TTS dispatch can actually reach this code.
    """
    try:
        print("[DiaTTS] generating dialogue audio", file=sys.stderr)

        # The processor is given a batch (list) of prompts; we send one.
        # Presumably the text uses speaker tags such as "[S1] ... [S2] ..."
        # — verify against the model card.
        encoded = self.processor(
            text=[request.text],
            padding=True,
            return_tensors="pt",
        )

        # Start from the processor output, then layer the generation
        # settings on top so they win over any identically named key.
        generate_kwargs = {**encoded, "max_new_tokens": self.max_tokens}
        sampling_defaults = (
            ("guidance_scale", 3.0),
            ("temperature", 1.8),
            ("top_p", 0.90),
            ("top_k", 45),
        )
        for option_name, fallback in sampling_defaults:
            generate_kwargs[option_name] = self.options.get(option_name, fallback)

        generated = self.model.generate(**generate_kwargs)

        # Turn the generated output into audio and persist it.
        decoded_audio = self.processor.batch_decode(generated)
        self.processor.save_audio(decoded_audio, request.dst)

        print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
        print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
        print("[DiaTTS] Dialogue generation done", file=sys.stderr)
    except Exception as err:
        # Report any failure back to the caller instead of raising.
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
    else:
        return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def OuteTTS(self, request, context):
|
||||
try:
|
||||
print("[OuteTTS] generating TTS", file=sys.stderr)
|
||||
gen_cfg = outetts.GenerationConfig(
|
||||
text="Speech synthesis is the artificial production of human speech.",
|
||||
temperature=0.1,
|
||||
repetition_penalty=1.1,
|
||||
temperature=self.options.get("temperature", 0.1),
|
||||
repetition_penalty=self.options.get("repetition_penalty", 1.1),
|
||||
max_length=self.max_tokens,
|
||||
speaker=self.speaker,
|
||||
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
|
||||
@@ -529,6 +604,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def TTS(self, request, context):
|
||||
if self.OuteTTS:
|
||||
return self.OuteTTS(request, context)
|
||||
|
||||
if self.DiaTTS:
|
||||
return self.DiaTTS(request, context)
|
||||
|
||||
model_name = request.model
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user