feat(transformers): add support to Dia (#5991)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-08-07 21:51:52 +02:00
committed by GitHub
parent 09457b9221
commit 003b9292fe

View File

@@ -22,7 +22,7 @@ import torch.cuda
XPU=os.environ.get("XPU", "0") == "1"
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
from scipy.io import wavfile
import outetts
from sentence_transformers import SentenceTransformer
@@ -90,6 +90,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.CUDA = torch.cuda.is_available()
self.OV=False
self.OuteTTS=False
self.DiaTTS=False
self.SentenceTransformer = False
device_map="cpu"
@@ -97,6 +98,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
quantization = None
autoTokenizer = True
# Parse options from request.Options
self.options = {}
options = request.Options
# The options are a list of strings in this form optname:optvalue
# We are storing all the options in a dict so we can use it later when generating
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1)
# if value is a number, convert it to the appropriate type
try:
if "." in value:
value = float(value)
else:
value = int(value)
except ValueError:
# Keep as string if conversion fails
pass
self.options[key] = value
print(f"Parsed options: {self.options}", file=sys.stderr)
if self.CUDA:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
if request.MainGPU:
@@ -202,6 +227,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
autoTokenizer = False
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
elif request.Type == "DiaForConditionalGeneration":
autoTokenizer = False
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
self.DiaTTS = True
elif request.Type == "OuteTTS":
autoTokenizer = False
options = request.Options
@@ -262,7 +292,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
self.max_tokens = self.model.config.max_position_embeddings
else:
self.max_tokens = 512
self.max_tokens = self.options.get("max_new_tokens", 512)
if autoTokenizer:
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
@@ -485,16 +515,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return_tensors="pt",
)
tokens = 256
if request.HasField('duration'):
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
guidance = 3.0
guidance = self.options.get("guidance_scale", 3.0)
if request.HasField('temperature'):
guidance = request.temperature
dosample = True
dosample = self.options.get("do_sample", True)
if request.HasField('sample'):
dosample = request.sample
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
@@ -506,13 +535,59 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
def DiaTTS(self, request, context):
"""
Generates dialogue audio using the Dia model.
Args:
request: A TTSRequest containing text dialogue and generation parameters
context: The gRPC context
Returns:
A Result object indicating success or failure
"""
try:
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
text = [request.text]
# Process the input
inputs = self.processor(text=text, padding=True, return_tensors="pt")
# Generate audio with parameters from options or defaults
generation_params = {
**inputs,
"max_new_tokens": self.max_tokens,
"guidance_scale": self.options.get("guidance_scale", 3.0),
"temperature": self.options.get("temperature", 1.8),
"top_p": self.options.get("top_p", 0.90),
"top_k": self.options.get("top_k", 45)
}
outputs = self.model.generate(**generation_params)
# Decode and save audio
outputs = self.processor.batch_decode(outputs)
self.processor.save_audio(outputs, request.dst)
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
def OuteTTS(self, request, context):
try:
print("[OuteTTS] generating TTS", file=sys.stderr)
gen_cfg = outetts.GenerationConfig(
text="Speech synthesis is the artificial production of human speech.",
temperature=0.1,
repetition_penalty=1.1,
temperature=self.options.get("temperature", 0.1),
repetition_penalty=self.options.get("repetition_penalty", 1.1),
max_length=self.max_tokens,
speaker=self.speaker,
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
@@ -529,6 +604,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def TTS(self, request, context):
if self.OuteTTS:
return self.OuteTTS(request, context)
if self.DiaTTS:
return self.DiaTTS(request, context)
model_name = request.model
try: