diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py index 88b410e54..3d34132fd 100644 --- a/backend/python/transformers/backend.py +++ b/backend/python/transformers/backend.py @@ -22,7 +22,7 @@ import torch.cuda XPU=os.environ.get("XPU", "0") == "1" from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM -from transformers import AutoProcessor, MusicgenForConditionalGeneration +from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration from scipy.io import wavfile import outetts from sentence_transformers import SentenceTransformer @@ -90,6 +90,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.CUDA = torch.cuda.is_available() self.OV=False self.OuteTTS=False + self.DiaTTS=False self.SentenceTransformer = False device_map="cpu" @@ -97,6 +98,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): quantization = None autoTokenizer = True + # Parse options from request.Options + self.options = {} + options = request.Options + + # The options are a list of strings in this form optname:optvalue + # We are storing all the options in a dict so we can use it later when generating + # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"] + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + # if value is a number, convert it to the appropriate type + try: + if "." 
in value: + value = float(value) + else: + value = int(value) + except ValueError: + # Keep as string if conversion fails + pass + self.options[key] = value + + print(f"Parsed options: {self.options}", file=sys.stderr) + if self.CUDA: from transformers import BitsAndBytesConfig, AutoModelForCausalLM if request.MainGPU: @@ -202,6 +227,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): autoTokenizer = False self.processor = AutoProcessor.from_pretrained(model_name) self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) + elif request.Type == "DiaForConditionalGeneration": + autoTokenizer = False + self.processor = AutoProcessor.from_pretrained(model_name) + self.model = DiaForConditionalGeneration.from_pretrained(model_name) + self.DiaTTS = True elif request.Type == "OuteTTS": autoTokenizer = False options = request.Options @@ -262,7 +292,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'): self.max_tokens = self.model.config.max_position_embeddings else: - self.max_tokens = 512 + self.max_tokens = self.options.get("max_new_tokens", 512) if autoTokenizer: self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True) @@ -485,16 +515,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return_tensors="pt", ) - tokens = 256 if request.HasField('duration'): tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second - guidance = 3.0 + guidance = self.options.get("guidance_scale", 3.0) if request.HasField('temperature'): guidance = request.temperature - dosample = True + dosample = self.options.get("do_sample", True) if request.HasField('sample'): dosample = request.sample - audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens) + audio_values = self.model.generate(**inputs, do_sample=dosample, 
def DiaTTS(self, request, context):
    """
    Generate dialogue audio with the Dia model and write it to ``request.dst``.

    Sampling parameters are read from ``self.options`` (the ``optname:optvalue``
    model options parsed by the model-loading code) and fall back to the
    defaults used below. A ``max_new_tokens`` option takes precedence over
    ``self.max_tokens``, because the loader only consults that option when it
    cannot derive a limit from the model config.

    Args:
        request: Request object; ``request.text`` is the dialogue text
            (expected in ``[S1] ... [S2] ...`` form) and ``request.dst`` is
            the path the audio is saved to.
        context: The gRPC context (not used here).

    Returns:
        ``backend_pb2.Result(success=True)`` when the audio was written,
        otherwise ``backend_pb2.Result(success=False, message=...)``.
    """
    # NOTE(review): the loader stores a boolean flag under the same name
    # (``self.DiaTTS = True``) and ``TTS`` dispatches with
    # ``self.DiaTTS(request, context)``. An instance attribute shadows a
    # method of the same name, so that call reaches the bool, not this
    # function. The flag or the method needs a distinct name (the
    # ``OuteTTS`` flag/method pair follows the same pattern).
    try:
        print("[DiaTTS] generating dialogue audio", file=sys.stderr)

        # Nothing to synthesize: fail with a clear message instead of
        # running generation on an empty prompt.
        if not request.text:
            return backend_pb2.Result(success=False, message="DiaTTS requires non-empty text")

        opts = self.options

        # The processor takes a batch; wrap the single dialogue string.
        inputs = self.processor(text=[request.text], padding=True, return_tensors="pt")

        # The option parser turns values without a "." into int and values
        # with one into float, so "temperature:2" or "top_k:45.0" would reach
        # generate() with the wrong numeric type. Coerce explicitly.
        generation_params = {
            **inputs,
            "max_new_tokens": int(opts.get("max_new_tokens", self.max_tokens)),
            "guidance_scale": float(opts.get("guidance_scale", 3.0)),
            "temperature": float(opts.get("temperature", 1.8)),
            "top_p": float(opts.get("top_p", 0.90)),
            "top_k": int(opts.get("top_k", 45)),
        }

        generated = self.model.generate(**generation_params)

        # Decode the generated codes to audio and write them to request.dst.
        decoded = self.processor.batch_decode(generated)
        self.processor.save_audio(decoded, request.dst)

        print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
        print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
        print("[DiaTTS] Dialogue generation done", file=sys.stderr)

    except Exception as err:
        # Boundary handler: report the failure to the caller, but also leave
        # a trace in the backend log so the cause is not lost.
        print(f"[DiaTTS] generation failed: {err!r}", file=sys.stderr)
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
    return backend_pb2.Result(success=True)
outetts.GenerationConfig( text="Speech synthesis is the artificial production of human speech.", - temperature=0.1, - repetition_penalty=1.1, + temperature=self.options.get("temperature", 0.1), + repetition_penalty=self.options.get("repetition_penalty", 1.1), max_length=self.max_tokens, speaker=self.speaker, # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness" @@ -529,6 +604,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def TTS(self, request, context): if self.OuteTTS: return self.OuteTTS(request, context) + + if self.DiaTTS: + return self.DiaTTS(request, context) model_name = request.model try: