chore(transformers): bump to >5.0 and generically load models (#9097)
* chore(transformers): bump to >5.0

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: refactor to use generic model loading

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
commit be25217955 (parent b74111feed), committed via GitHub
@@ -21,11 +21,13 @@ import torch.cuda
 XPU=os.environ.get("XPU", "0") == "1"
-from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
-from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
+import transformers as transformers_module
+from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 from scipy.io import wavfile
 from sentence_transformers import SentenceTransformer
 
+# Backward-compat aliases for model types
+TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}
+
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
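With transformers imported as a module, architecture classes are looked up by name at load time instead of being imported one by one, and TYPE_ALIASES keeps old config values such as Type: "Mamba" working. A quick illustration of the lookup that replaces the per-class imports (a sketch; "Mamba" used as the example request type):

import transformers as transformers_module

TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}

# Equivalent to `from transformers import MambaForCausalLM`, but the class
# name can come from the request at runtime:
name = TYPE_ALIASES.get("Mamba", "Mamba")
ModelClass = getattr(transformers_module, name)  # -> MambaForCausalLM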
@@ -52,32 +54,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
     This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
     """
     def Health(self, request, context):
         """
         A gRPC method that returns the health status of the backend service.
 
         Args:
             request: A HealthRequest object that contains the request parameters.
             context: A grpc.ServicerContext object that provides information about the RPC.
 
         Returns:
             A Reply object that contains the health status of the backend service.
         """
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
 
     def LoadModel(self, request, context):
         """
         A gRPC method that loads a model into memory.
 
         Args:
             request: A LoadModelRequest object that contains the request parameters.
             context: A grpc.ServicerContext object that provides information about the RPC.
 
         Returns:
             A Result object that contains the result of the LoadModel operation.
         """
         model_name = request.Model
 
         # Check to see if the Model exists in the filesystem already.
         if os.path.exists(request.ModelFile):
             model_name = request.ModelFile
@@ -88,8 +69,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         self.CUDA = torch.cuda.is_available()
         self.OV=False
-        self.DiaTTS=False
+        self.GenericTTS=False
         self.SentenceTransformer = False
+        self.processor = None
 
         device_map="cpu"
         mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -101,7 +83,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         # Parse options from request.Options
         self.options = {}
         options = request.Options
 
         # The options are a list of strings in this form optname:optvalue
         # We are storing all the options in a dict so we can use it later when generating
         # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
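The parsing loop that fills self.options sits outside this hunk; a minimal sketch of the scheme the comments describe, with best-effort type coercion (an assumed helper, not the commit's exact code):

def parse_options(options):
    parsed = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)  # split only on the first colon
        for cast in (int, float):
            try:
                parsed[key] = cast(value)
                break
            except ValueError:
                continue
        else:
            parsed[key] = {"true": True, "false": False}.get(value.lower(), value)
    return parsed

# parse_options(["max_new_tokens:3072", "temperature:1.8", "do_sample:true"])
# -> {"max_new_tokens": 3072, "temperature": 1.8, "do_sample": True}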
@@ -123,7 +105,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         print(f"Parsed options: {self.options}", file=sys.stderr)
 
         if self.CUDA:
-            from transformers import BitsAndBytesConfig, AutoModelForCausalLM
+            from transformers import BitsAndBytesConfig
             if request.MainGPU:
                 device_map=request.MainGPU
             else:
@@ -140,40 +122,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             quantization = BitsAndBytesConfig(
-                load_in_4bit=False,
-                bnb_4bit_compute_dtype = None,
-                load_in_8bit=True,
+                load_in_8bit=True,
             )
 
         try:
-            if request.Type == "AutoModelForCausalLM":
-                if XPU:
-                    import intel_extension_for_pytorch as ipex
-                    from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
-                    device_map="xpu"
-                    compute=torch.float16
-                    if request.Quantization == "xpu_4bit":
-                        xpu_4bit = True
-                        xpu_8bit = False
-                    elif request.Quantization == "xpu_8bit":
-                        xpu_4bit = False
-                        xpu_8bit = True
-                    else:
-                        xpu_4bit = False
-                        xpu_8bit = False
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name,
-                                                                      trust_remote_code=request.TrustRemoteCode,
-                                                                      use_safetensors=True,
-                                                                      device_map=device_map,
-                                                                      load_in_4bit=xpu_4bit,
-                                                                      load_in_8bit=xpu_8bit,
-                                                                      torch_dtype=compute)
-                else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name,
-                                                                      trust_remote_code=request.TrustRemoteCode,
-                                                                      use_safetensors=True,
-                                                                      quantization_config=quantization,
-                                                                      device_map=device_map,
-                                                                      torch_dtype=compute)
+            if XPU and request.Type == "AutoModelForCausalLM":
+                import intel_extension_for_pytorch as ipex
+                from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
+                device_map="xpu"
+                compute=torch.float16
+                if request.Quantization == "xpu_4bit":
+                    xpu_4bit = True
+                    xpu_8bit = False
+                elif request.Quantization == "xpu_8bit":
+                    xpu_4bit = False
+                    xpu_8bit = True
+                else:
+                    xpu_4bit = False
+                    xpu_8bit = False
+                self.model = AutoModelForCausalLM.from_pretrained(model_name,
+                                                                  trust_remote_code=request.TrustRemoteCode,
+                                                                  device_map=device_map,
+                                                                  load_in_4bit=xpu_4bit,
+                                                                  load_in_8bit=xpu_8bit,
+                                                                  torch_dtype=compute)
             elif request.Type == "OVModelForCausalLM":
                 from optimum.intel.openvino import OVModelForCausalLM
                 from openvino.runtime import Core
@@ -185,14 +158,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 devices = Core().available_devices
                 if "GPU" in " ".join(devices):
                     device_map="AUTO:GPU"
                 # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
                 # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
                 if "CPU" or "NPU" in device_map:
                     if "-CPU" or "-NPU" not in device_map:
                         ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
                     else:
                         ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
                 self.model = OVModelForCausalLM.from_pretrained(model_name,
                                                                 compile=True,
                                                                 trust_remote_code=request.TrustRemoteCode,
                                                                 ov_config=ovconfig,
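A note on the guards above: Python parses `"CPU" or "NPU" in device_map` as `"CPU" or ("NPU" in device_map)`, and the non-empty literal "CPU" is always truthy, so both tests succeed regardless of device_map and the plain PERFORMANCE_HINT config is always selected. A membership test matching the apparent intent would look like this (a sketch, not what the commit ships):

if any(d in device_map for d in ("CPU", "NPU")):
    if not any(d in device_map for d in ("-CPU", "-NPU")):
        ovconfig = {"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
    else:
        ovconfig = {"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT", "GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}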
@@ -209,59 +180,60 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 devices = Core().available_devices
                 if "GPU" in " ".join(devices):
                     device_map="AUTO:GPU"
                 # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
                 # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
                 if "CPU" or "NPU" in device_map:
                     if "-CPU" or "-NPU" not in device_map:
                         ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
                     else:
                         ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
                 self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
                                                                          compile=True,
                                                                          trust_remote_code=request.TrustRemoteCode,
                                                                          ov_config=ovconfig,
                                                                          export=True,
                                                                          device=device_map)
                 self.OV = True
-            elif request.Type == "MusicgenForConditionalGeneration":
-                autoTokenizer = False
-                self.processor = AutoProcessor.from_pretrained(model_name)
-                self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
-            elif request.Type == "DiaForConditionalGeneration":
-                autoTokenizer = False
-                print("DiaForConditionalGeneration", file=sys.stderr)
-                self.processor = AutoProcessor.from_pretrained(model_name)
-                self.model = DiaForConditionalGeneration.from_pretrained(model_name)
-                if self.CUDA:
-                    self.model = self.model.to("cuda")
-                    self.processor = self.processor.to("cuda")
-                print("DiaForConditionalGeneration loaded", file=sys.stderr)
-                self.DiaTTS = True
             elif request.Type == "SentenceTransformer":
                 autoTokenizer = False
                 self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
                 self.SentenceTransformer = True
-            elif request.Type == "Mamba":
-                autoTokenizer = False
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-                self.model = MambaForCausalLM.from_pretrained(model_name)
             else:
-                print("Automodel", file=sys.stderr)
-                self.model = AutoModel.from_pretrained(model_name,
-                                                       trust_remote_code=request.TrustRemoteCode,
-                                                       use_safetensors=True,
-                                                       quantization_config=quantization,
-                                                       device_map=device_map,
-                                                       torch_dtype=compute)
+                # Generic: dynamically resolve model class from transformers
+                model_type = TYPE_ALIASES.get(request.Type, request.Type)
+                ModelClass = AutoModel  # default
+                if model_type and hasattr(transformers_module, model_type):
+                    ModelClass = getattr(transformers_module, model_type)
+                    print(f"Using model class: {model_type}", file=sys.stderr)
+                else:
+                    print(f"Using default AutoModel (type={request.Type!r})", file=sys.stderr)
+
+                self.model = ModelClass.from_pretrained(
+                    model_name,
+                    trust_remote_code=request.TrustRemoteCode,
+                    quantization_config=quantization,
+                    device_map=device_map,
+                    torch_dtype=compute,
+                )
+
+                # Try to load a processor (needed for TTS/audio models)
+                try:
+                    self.processor = AutoProcessor.from_pretrained(
+                        model_name,
+                        trust_remote_code=request.TrustRemoteCode,
+                    )
+                    self.GenericTTS = True
+                    print(f"Loaded processor for {model_name}", file=sys.stderr)
+                except Exception:
+                    self.processor = None
 
             if request.ContextSize > 0:
                 self.max_tokens = request.ContextSize
             elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
                 self.max_tokens = self.model.config.max_position_embeddings
             else:
                 self.max_tokens = self.options.get("max_new_tokens", 512)
 
             if autoTokenizer:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
             self.XPU = False
 
             if XPU and self.OV == False:
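This hunk is the heart of the refactor: the per-architecture elif ladder (Musicgen, Dia, Mamba, plain AutoModel) collapses into one path that resolves whatever class request.Type names on the transformers module, with AutoModel as the fallback and an opportunistic AutoProcessor probe for audio/TTS models. A standalone sketch of the resolution rule with example inputs (illustrative calls mirroring the hunk above):

import transformers as transformers_module
from transformers import AutoModel

TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}

def resolve_model_class(requested_type):
    # Map legacy names, then look the class up on the transformers module itself.
    model_type = TYPE_ALIASES.get(requested_type, requested_type)
    if model_type and hasattr(transformers_module, model_type):
        return getattr(transformers_module, model_type)
    return AutoModel

# resolve_model_class("MusicgenForConditionalGeneration") -> MusicgenForConditionalGeneration
# resolve_model_class("Mamba")                            -> MambaForCausalLM (via the alias)
# resolve_model_class("") or an unknown name              -> AutoModel

The design consequence: a new transformers architecture becomes loadable by setting the backend model type to its class name, with no further backend changes.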
@@ -275,22 +247,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         except Exception as err:
             print("Error:", err, file=sys.stderr)
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
         # Replace this with your desired response
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
     def Embedding(self, request, context):
         """
         A gRPC method that calculates embeddings for a given sentence.
 
         Args:
             request: An EmbeddingRequest object that contains the request parameters.
             context: A grpc.ServicerContext object that provides information about the RPC.
 
         Returns:
             An EmbeddingResult object that contains the calculated embeddings.
         """
         set_seed(request.Seed)
         # Tokenize input
         max_length = 512
@@ -303,13 +262,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
             embeds = self.model.encode(request.Embeddings)
         else:
             encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
 
             # Create word embeddings
             if self.CUDA:
                 encoded_input = encoded_input.to("cuda")
 
             with torch.no_grad():
                 model_output = self.model(**encoded_input)
 
         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
@@ -317,11 +276,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         embeds = sentence_embeddings[0]
         return backend_pb2.EmbeddingResult(embeddings=embeds)
 
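The sentence_embeddings pooled above come from a helper outside this hunk; the standard approach is an attention-mask-weighted mean over token embeddings (a sketch of the usual technique, not necessarily this backend's exact helper):

import torch

def mean_pooling(model_output, attention_mask):
    # Average token embeddings, ignoring padding positions via the attention mask.
    token_embeddings = model_output[0]  # last hidden state
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)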
     async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
         if request.TopP < 0 or request.TopP > 1:
             request.TopP = 1
 
         if request.TopK <= 0:
             request.TopK = 50
@@ -334,7 +293,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             request.Temperature == None
 
         prompt = request.Prompt
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
 
         inputs = self.tokenizer(prompt, return_tensors="pt")
@@ -363,10 +322,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                                 skip_prompt=True,
                                                 skip_special_tokens=True)
             config=dict(inputs,
                         max_new_tokens=max_tokens,
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
                         do_sample=sample,
                         attention_mask=inputs["attention_mask"],
                         eos_token_id=self.tokenizer.eos_token_id,
@@ -387,18 +346,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         else:
             if XPU and self.OV == False:
                 outputs = self.model.generate(inputs["input_ids"],
                                               max_new_tokens=max_tokens,
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
                                               do_sample=sample,
                                               pad_token=self.tokenizer.eos_token_id)
             else:
                 outputs = self.model.generate(**inputs,
                                               max_new_tokens=max_tokens,
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
                                               do_sample=sample,
                                               eos_token_id=self.tokenizer.eos_token_id,
                                               pad_token_id=self.tokenizer.eos_token_id,
@@ -413,31 +372,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
 
     async def Predict(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters.
 
         Args:
             request: The predict request.
             context: The gRPC context.
 
         Returns:
             backend_pb2.Reply: The predict result.
         """
         gen = self._predict(request, context, streaming=False)
         res = await gen.__anext__()
         return res
 
     async def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.
 
         Args:
             request: The predict stream request.
             context: The gRPC context.
 
         Returns:
             backend_pb2.Result: The predict stream result.
         """
         iterations = self._predict(request, context, streaming=True)
         try:
             async for iteration in iterations:
@@ -455,18 +394,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if self.model is None:
                 if model_name == "":
                     return backend_pb2.Result(success=False, message="request.model is required")
-                self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+                # Dynamically resolve model class if configured, otherwise default to MusicgenForConditionalGeneration
+                model_type = self.options.get("model_type", "MusicgenForConditionalGeneration")
+                ModelClass = getattr(transformers_module, model_type)
+                self.model = ModelClass.from_pretrained(model_name)
             inputs = None
             if request.text == "":
                 inputs = self.model.get_unconditional_inputs(num_samples=1)
             elif request.HasField('src'):
                 # TODO SECURITY CODE GOES HERE LOL
                 # WHO KNOWS IF THIS WORKS???
                 sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
 
                 if request.HasField('src_divisor'):
                     wsamples = wsamples[: len(wsamples) // request.src_divisor]
 
                 inputs = self.processor(
                     audio=wsamples,
                     sampling_rate=sample_rate,
@@ -480,7 +420,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     padding=True,
                     return_tensors="pt",
                 )
 
             if request.HasField('duration'):
                 tokens = int(request.duration * 51.2)  # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
             guidance = self.options.get("guidance_scale", 3.0)
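For reference, the 51.2 factor is just Musicgen's token rate spelled out: 256 tokens ≈ 5 seconds of audio, i.e. 256 / 5 = 51.2 tokens per second, so a 10-second request maps to int(10 * 51.2) = 512 new tokens.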
@@ -490,92 +430,97 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if request.HasField('sample'):
                 dosample = request.sample
             audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
-            print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
-            sampling_rate = self.model.config.audio_encoder.sampling_rate
-            wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
-            print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
-            print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
-            print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
+            print("[transformers] SoundGeneration generated!", file=sys.stderr)
+
+            # Save audio output
+            if hasattr(self.processor, 'save_audio'):
+                if hasattr(self.processor, 'batch_decode'):
+                    try:
+                        audio_values = self.processor.batch_decode(audio_values)
+                    except Exception:
+                        pass
+                self.processor.save_audio(audio_values, request.dst)
+            else:
+                sampling_rate = self.model.config.audio_encoder.sampling_rate
+                wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
+
+            print("[transformers] SoundGeneration saved to", request.dst, file=sys.stderr)
             print(request, file=sys.stderr)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(success=True)
 
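Note that, unlike LoadModel, this lazy-load path does not guard the lookup with hasattr, so a model_type option naming a class that transformers does not export raises AttributeError (surfaced as a failed Result by the surrounding try). A guarded variant would mirror LoadModel (a sketch, not the commit's code):

model_type = self.options.get("model_type", "MusicgenForConditionalGeneration")
ModelClass = getattr(transformers_module, model_type, None)
if ModelClass is None:
    return backend_pb2.Result(success=False, message=f"unknown model_type {model_type!r}")
self.model = ModelClass.from_pretrained(model_name)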
-    def CallDiaTTS(self, request, context):
-        """
-        Generates dialogue audio using the Dia model.
-
-        Args:
-            request: A TTSRequest containing text dialogue and generation parameters
-            context: The gRPC context
-
-        Returns:
-            A Result object indicating success or failure
-        """
-        try:
-            print("[DiaTTS] generating dialogue audio", file=sys.stderr)
-
-            # Prepare text input - expect dialogue format like [S1] ... [S2] ...
-            text = [request.text]
-
-            # Process the input
-            inputs = self.processor(text=text, padding=True, return_tensors="pt")
-
-            # Generate audio with parameters from options or defaults
-            generation_params = {
-                **inputs,
-                "max_new_tokens": self.max_tokens,
-                "guidance_scale": self.options.get("guidance_scale", 3.0),
-                "temperature": self.options.get("temperature", 1.8),
-                "top_p": self.options.get("top_p", 0.90),
-                "top_k": self.options.get("top_k", 45)
-            }
-
-            outputs = self.model.generate(**generation_params)
-
-            # Decode and save audio
-            outputs = self.processor.batch_decode(outputs)
-            self.processor.save_audio(outputs, request.dst)
-
-            print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
-            print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
-            print("[DiaTTS] Dialogue generation done", file=sys.stderr)
-
-        except Exception as err:
-            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
-        return backend_pb2.Result(success=True)
 
     # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
     def TTS(self, request, context):
-        if self.DiaTTS:
-            print("DiaTTS", file=sys.stderr)
-            return self.CallDiaTTS(request, context)
-
         model_name = request.model
         try:
             if self.processor is None:
                 if model_name == "":
                     return backend_pb2.Result(success=False, message="request.model is required")
                 self.processor = AutoProcessor.from_pretrained(model_name)
             if self.model is None:
                 if model_name == "":
                     return backend_pb2.Result(success=False, message="request.model is required")
-                self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
-            inputs = self.processor(
-                text=[request.text],
-                padding=True,
-                return_tensors="pt",
-            )
-            tokens = self.max_tokens  # No good place to set the "length" in TTS, so use 10s as a sane default
-            audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
-            print("[transformers-musicgen] TTS generated!", file=sys.stderr)
-            sampling_rate = self.model.config.audio_encoder.sampling_rate
-            wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
-            print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
-            print("[transformers-musicgen] TTS for", file=sys.stderr)
-            print(request, file=sys.stderr)
+            text = request.text
+            print(f"[transformers] TTS generating for text: {text[:100]}...", file=sys.stderr)
+
+            # Build inputs based on processor capabilities
+            if request.voice and os.path.exists(request.voice):
+                # Voice cloning: use chat template with reference audio
+                chat_template = [{
+                    "role": "0",
+                    "content": [
+                        {"type": "text", "text": text},
+                        {"type": "audio", "path": request.voice},
+                    ],
+                }]
+                inputs = self.processor.apply_chat_template(
+                    chat_template, tokenize=True, return_dict=True,
+                ).to(self.model.device, self.model.dtype)
+            elif hasattr(self.processor, 'apply_chat_template'):
+                # Models that use chat template format (VibeVoice, CSM, etc.)
+                chat_template = [{"role": "0", "content": [{"type": "text", "text": text}]}]
+                try:
+                    inputs = self.processor.apply_chat_template(
+                        chat_template, tokenize=True, return_dict=True,
+                    ).to(self.model.device, self.model.dtype)
+                except Exception:
+                    # Fallback if chat template fails (not all processors support it)
+                    inputs = self.processor(text=[text], padding=True, return_tensors="pt")
+                    if self.CUDA:
+                        inputs = inputs.to("cuda")
+            else:
+                # Direct processor call (Musicgen, etc.)
+                inputs = self.processor(text=[text], padding=True, return_tensors="pt")
+                if self.CUDA:
+                    inputs = inputs.to("cuda")
+
+            # Build generation kwargs from self.options
+            gen_kwargs = {**inputs, "max_new_tokens": self.max_tokens}
+            for key in ["guidance_scale", "temperature", "top_p", "top_k", "do_sample"]:
+                if key in self.options:
+                    gen_kwargs[key] = self.options[key]
+
+            # Add noise scheduler if configured (e.g., for VibeVoice)
+            noise_scheduler_type = self.options.get("noise_scheduler", None)
+            if noise_scheduler_type:
+                import diffusers
+                SchedulerClass = getattr(diffusers, noise_scheduler_type)
+                scheduler_kwargs = {}
+                for key in ["beta_schedule", "prediction_type"]:
+                    if key in self.options:
+                        scheduler_kwargs[key] = self.options[key]
+                gen_kwargs["noise_scheduler"] = SchedulerClass(**scheduler_kwargs)
+
+            # Generate audio
+            audio = self.model.generate(**gen_kwargs)
+            print("[transformers] TTS generated!", file=sys.stderr)
+
+            # Save audio output
+            if hasattr(self.processor, 'save_audio'):
+                if hasattr(self.processor, 'batch_decode'):
+                    try:
+                        audio = self.processor.batch_decode(audio)
+                    except Exception:
+                        pass
+                self.processor.save_audio(audio, request.dst)
+            else:
+                sampling_rate = self.model.config.audio_encoder.sampling_rate
+                wavfile.write(request.dst, rate=sampling_rate, data=audio[0, 0].numpy())
+
+            print("[transformers] TTS saved to", request.dst, file=sys.stderr)
+
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(success=True)
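The noise-scheduler hook in the new TTS body is driven entirely by options, so diffusion-decoder TTS models can be tuned per deployment. A hypothetical option set (scheduler name and kwargs are illustrative; any class exported by diffusers whose constructor accepts them would work):

# optname:optvalue strings, parsed into self.options at LoadModel time:
options = [
    "noise_scheduler:DPMSolverMultistepScheduler",
    "beta_schedule:squaredcos_cap_v2",
    "prediction_type:v_prediction",
]
# TTS() then effectively builds:
#   gen_kwargs["noise_scheduler"] = diffusers.DPMSolverMultistepScheduler(
#       beta_schedule="squaredcos_cap_v2", prediction_type="v_prediction")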
@@ -2,7 +2,9 @@ torch==2.7.1
 llvmlite==0.43.0
 numba==0.60.0
 accelerate
-transformers
+transformers>=5.0.0
 bitsandbytes
 sentence-transformers==5.2.3
+diffusers
+soundfile
 protobuf==6.33.5
@@ -2,7 +2,9 @@ torch==2.7.1
 accelerate
 llvmlite==0.43.0
 numba==0.60.0
-transformers
+transformers>=5.0.0
 bitsandbytes
 sentence-transformers==5.2.3
+diffusers
+soundfile
 protobuf==6.33.5
@@ -2,7 +2,9 @@
 torch==2.9.0
 llvmlite==0.43.0
 numba==0.60.0
-transformers
+transformers>=5.0.0
 bitsandbytes
 sentence-transformers==5.2.3
+diffusers
+soundfile
 protobuf==6.33.5
@@ -1,9 +1,11 @@
 --extra-index-url https://download.pytorch.org/whl/rocm6.4
 torch==2.8.0+rocm6.4
 accelerate
-transformers
+transformers>=5.0.0
 llvmlite==0.43.0
 numba==0.60.0
 bitsandbytes
 sentence-transformers==5.2.3
+diffusers
+soundfile
 protobuf==6.33.5
@@ -3,7 +3,9 @@ torch
 optimum[openvino]
 llvmlite==0.43.0
 numba==0.60.0
-transformers
+transformers>=5.0.0
 bitsandbytes
 sentence-transformers==5.2.3
+diffusers
+soundfile
 protobuf==6.33.5
@@ -2,7 +2,9 @@ torch==2.7.1
 llvmlite==0.43.0
 numba==0.60.0
 accelerate
-transformers
+transformers>=5.0.0
 bitsandbytes
 sentence-transformers==5.2.3
+diffusers
+soundfile
 protobuf==6.33.5