From be2521795504866722dd3ce22b02ff5f5922bdc0 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 22 Mar 2026 00:57:54 +0100 Subject: [PATCH] chore(transformers): bump to >5.0 and generically load models (#9097) * chore(transformers): bump to >5.0 Signed-off-by: Ettore Di Giacinto * chore: refactor to use generic model loading Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- backend/python/transformers/backend.py | 389 ++++++++---------- .../python/transformers/requirements-cpu.txt | 4 +- .../transformers/requirements-cublas12.txt | 4 +- .../transformers/requirements-cublas13.txt | 4 +- .../transformers/requirements-hipblas.txt | 4 +- .../transformers/requirements-intel.txt | 4 +- .../python/transformers/requirements-mps.txt | 4 +- 7 files changed, 185 insertions(+), 228 deletions(-) diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py index 450664d6b..54fd9193d 100644 --- a/backend/python/transformers/backend.py +++ b/backend/python/transformers/backend.py @@ -21,11 +21,13 @@ import torch.cuda XPU=os.environ.get("XPU", "0") == "1" -from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM -from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration +import transformers as transformers_module +from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria from scipy.io import wavfile from sentence_transformers import SentenceTransformer +# Backward-compat aliases for model types +TYPE_ALIASES = {"Mamba": "MambaForCausalLM"} _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -52,32 +54,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding. """ def Health(self, request, context): - """ - A gRPC method that returns the health status of the backend service. - - Args: - request: A HealthRequest object that contains the request parameters. - context: A grpc.ServicerContext object that provides information about the RPC. - - Returns: - A Reply object that contains the health status of the backend service. - """ return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): - """ - A gRPC method that loads a model into memory. - - Args: - request: A LoadModelRequest object that contains the request parameters. - context: A grpc.ServicerContext object that provides information about the RPC. - - Returns: - A Result object that contains the result of the LoadModel operation. - """ - model_name = request.Model - + # Check to see if the Model exists in the filesystem already. 
if os.path.exists(request.ModelFile): model_name = request.ModelFile @@ -88,8 +69,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.CUDA = torch.cuda.is_available() self.OV=False - self.DiaTTS=False + self.GenericTTS=False self.SentenceTransformer = False + self.processor = None device_map="cpu" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() @@ -101,7 +83,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): # Parse options from request.Options self.options = {} options = request.Options - + # The options are a list of strings in this form optname:optvalue # We are storing all the options in a dict so we can use it later when generating # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"] @@ -123,7 +105,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): print(f"Parsed options: {self.options}", file=sys.stderr) if self.CUDA: - from transformers import BitsAndBytesConfig, AutoModelForCausalLM + from transformers import BitsAndBytesConfig if request.MainGPU: device_map=request.MainGPU else: @@ -140,40 +122,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): quantization = BitsAndBytesConfig( load_in_4bit=False, bnb_4bit_compute_dtype = None, - load_in_8bit=True, + load_in_8bit=True, ) try: - if request.Type == "AutoModelForCausalLM": - if XPU: - import intel_extension_for_pytorch as ipex - from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM + if XPU and request.Type == "AutoModelForCausalLM": + import intel_extension_for_pytorch as ipex + from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM - device_map="xpu" - compute=torch.float16 - if request.Quantization == "xpu_4bit": - xpu_4bit = True - xpu_8bit = False - elif request.Quantization == "xpu_8bit": - xpu_4bit = False - xpu_8bit = True - else: - xpu_4bit = False - xpu_8bit = False - self.model = AutoModelForCausalLM.from_pretrained(model_name, - trust_remote_code=request.TrustRemoteCode, - use_safetensors=True, - device_map=device_map, - load_in_4bit=xpu_4bit, - load_in_8bit=xpu_8bit, - torch_dtype=compute) + device_map="xpu" + compute=torch.float16 + if request.Quantization == "xpu_4bit": + xpu_4bit = True + xpu_8bit = False + elif request.Quantization == "xpu_8bit": + xpu_4bit = False + xpu_8bit = True else: - self.model = AutoModelForCausalLM.from_pretrained(model_name, - trust_remote_code=request.TrustRemoteCode, - use_safetensors=True, - quantization_config=quantization, - device_map=device_map, - torch_dtype=compute) + xpu_4bit = False + xpu_8bit = False + self.model = AutoModelForCausalLM.from_pretrained(model_name, + trust_remote_code=request.TrustRemoteCode, + device_map=device_map, + load_in_4bit=xpu_4bit, + load_in_8bit=xpu_8bit, + torch_dtype=compute) elif request.Type == "OVModelForCausalLM": from optimum.intel.openvino import OVModelForCausalLM from openvino.runtime import Core @@ -185,14 +158,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): devices = Core().available_devices if "GPU" in " ".join(devices): device_map="AUTO:GPU" - # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected. 
- # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html if "CPU" or "NPU" in device_map: if "-CPU" or "-NPU" not in device_map: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"} else: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"} - self.model = OVModelForCausalLM.from_pretrained(model_name, + self.model = OVModelForCausalLM.from_pretrained(model_name, compile=True, trust_remote_code=request.TrustRemoteCode, ov_config=ovconfig, @@ -209,59 +180,60 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): devices = Core().available_devices if "GPU" in " ".join(devices): device_map="AUTO:GPU" - # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected. - # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html if "CPU" or "NPU" in device_map: if "-CPU" or "-NPU" not in device_map: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"} else: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"} - self.model = OVModelForFeatureExtraction.from_pretrained(model_name, + self.model = OVModelForFeatureExtraction.from_pretrained(model_name, compile=True, trust_remote_code=request.TrustRemoteCode, - ov_config=ovconfig, + ov_config=ovconfig, export=True, device=device_map) self.OV = True - elif request.Type == "MusicgenForConditionalGeneration": - autoTokenizer = False - self.processor = AutoProcessor.from_pretrained(model_name) - self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) - elif request.Type == "DiaForConditionalGeneration": - autoTokenizer = False - print("DiaForConditionalGeneration", file=sys.stderr) - self.processor = AutoProcessor.from_pretrained(model_name) - self.model = DiaForConditionalGeneration.from_pretrained(model_name) - if self.CUDA: - self.model = self.model.to("cuda") - self.processor = self.processor.to("cuda") - print("DiaForConditionalGeneration loaded", file=sys.stderr) - self.DiaTTS = True elif request.Type == "SentenceTransformer": autoTokenizer = False self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode) self.SentenceTransformer = True - elif request.Type == "Mamba": - autoTokenizer = False - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = MambaForCausalLM.from_pretrained(model_name) else: - print("Automodel", file=sys.stderr) - self.model = AutoModel.from_pretrained(model_name, - trust_remote_code=request.TrustRemoteCode, - use_safetensors=True, - quantization_config=quantization, - device_map=device_map, - torch_dtype=compute) + # Generic: dynamically resolve model class from transformers + model_type = TYPE_ALIASES.get(request.Type, request.Type) + ModelClass = AutoModel # default + if model_type and hasattr(transformers_module, model_type): + ModelClass = getattr(transformers_module, model_type) + print(f"Using model class: {model_type}", file=sys.stderr) + else: + print(f"Using default AutoModel (type={request.Type!r})", file=sys.stderr) + + self.model = ModelClass.from_pretrained( + model_name, + trust_remote_code=request.TrustRemoteCode, + quantization_config=quantization, + device_map=device_map, + torch_dtype=compute, + ) + + # Try to load a processor (needed for TTS/audio models) + try: + self.processor = AutoProcessor.from_pretrained( + model_name, + 
trust_remote_code=request.TrustRemoteCode, + ) + self.GenericTTS = True + print(f"Loaded processor for {model_name}", file=sys.stderr) + except Exception: + self.processor = None + if request.ContextSize > 0: self.max_tokens = request.ContextSize elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'): self.max_tokens = self.model.config.max_position_embeddings else: self.max_tokens = self.options.get("max_new_tokens", 512) - + if autoTokenizer: - self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.XPU = False if XPU and self.OV == False: @@ -275,22 +247,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): except Exception as err: print("Error:", err, file=sys.stderr) return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") - # Implement your logic here for the LoadModel service - # Replace this with your desired response return backend_pb2.Result(message="Model loaded successfully", success=True) def Embedding(self, request, context): - """ - A gRPC method that calculates embeddings for a given sentence. - - Args: - request: An EmbeddingRequest object that contains the request parameters. - context: A grpc.ServicerContext object that provides information about the RPC. - - Returns: - An EmbeddingResult object that contains the calculated embeddings. - """ - set_seed(request.Seed) # Tokenize input max_length = 512 @@ -303,13 +262,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr) embeds = self.model.encode(request.Embeddings) else: - encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt") + encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt") # Create word embeddings if self.CUDA: encoded_input = encoded_input.to("cuda") - with torch.no_grad(): + with torch.no_grad(): model_output = self.model(**encoded_input) # Pool to get sentence embeddings; i.e. 
generate one 1024 vector for the entire sentence @@ -317,11 +276,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): embeds = sentence_embeddings[0] return backend_pb2.EmbeddingResult(embeddings=embeds) - async def _predict(self, request, context, streaming=False): + async def _predict(self, request, context, streaming=False): set_seed(request.Seed) if request.TopP < 0 or request.TopP > 1: request.TopP = 1 - + if request.TopK <= 0: request.TopK = 50 @@ -334,7 +293,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): request.Temperature == None prompt = request.Prompt - if not request.Prompt and request.UseTokenizerTemplate and request.Messages: + if not request.Prompt and request.UseTokenizerTemplate and request.Messages: prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True) inputs = self.tokenizer(prompt, return_tensors="pt") @@ -363,10 +322,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): skip_prompt=True, skip_special_tokens=True) config=dict(inputs, - max_new_tokens=max_tokens, - temperature=request.Temperature, + max_new_tokens=max_tokens, + temperature=request.Temperature, top_p=request.TopP, - top_k=request.TopK, + top_k=request.TopK, do_sample=sample, attention_mask=inputs["attention_mask"], eos_token_id=self.tokenizer.eos_token_id, @@ -387,18 +346,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): else: if XPU and self.OV == False: outputs = self.model.generate(inputs["input_ids"], - max_new_tokens=max_tokens, - temperature=request.Temperature, + max_new_tokens=max_tokens, + temperature=request.Temperature, top_p=request.TopP, - top_k=request.TopK, + top_k=request.TopK, do_sample=sample, pad_token=self.tokenizer.eos_token_id) else: outputs = self.model.generate(**inputs, - max_new_tokens=max_tokens, - temperature=request.Temperature, + max_new_tokens=max_tokens, + temperature=request.Temperature, top_p=request.TopP, - top_k=request.TopK, + top_k=request.TopK, do_sample=sample, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.eos_token_id, @@ -413,31 +372,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) async def Predict(self, request, context): - """ - Generates text based on the given prompt and sampling parameters. - - Args: - request: The predict request. - context: The gRPC context. - - Returns: - backend_pb2.Reply: The predict result. - """ gen = self._predict(request, context, streaming=False) res = await gen.__anext__() return res async def PredictStream(self, request, context): - """ - Generates text based on the given prompt and sampling parameters, and streams the results. - - Args: - request: The predict stream request. - context: The gRPC context. - - Returns: - backend_pb2.Result: The predict stream result. 
- """ iterations = self._predict(request, context, streaming=True) try: async for iteration in iterations: @@ -455,18 +394,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): if self.model is None: if model_name == "": return backend_pb2.Result(success=False, message="request.model is required") - self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) + # Dynamically resolve model class if configured, otherwise default to MusicgenForConditionalGeneration + model_type = self.options.get("model_type", "MusicgenForConditionalGeneration") + ModelClass = getattr(transformers_module, model_type) + self.model = ModelClass.from_pretrained(model_name) inputs = None if request.text == "": inputs = self.model.get_unconditional_inputs(num_samples=1) elif request.HasField('src'): - # TODO SECURITY CODE GOES HERE LOL - # WHO KNOWS IF THIS WORKS??? sample_rate, wsamples = wavfile.read('path_to_your_file.wav') - + if request.HasField('src_divisor'): wsamples = wsamples[: len(wsamples) // request.src_divisor] - + inputs = self.processor( audio=wsamples, sampling_rate=sample_rate, @@ -480,7 +420,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): padding=True, return_tensors="pt", ) - + if request.HasField('duration'): tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second guidance = self.options.get("guidance_scale", 3.0) @@ -490,92 +430,97 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): if request.HasField('sample'): dosample = request.sample audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens) - print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr) - sampling_rate = self.model.config.audio_encoder.sampling_rate - wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) - print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr) - print("[transformers-musicgen] SoundGeneration for", file=sys.stderr) - print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr) + print("[transformers] SoundGeneration generated!", file=sys.stderr) + + # Save audio output + if hasattr(self.processor, 'save_audio'): + if hasattr(self.processor, 'batch_decode'): + try: + audio_values = self.processor.batch_decode(audio_values) + except Exception: + pass + self.processor.save_audio(audio_values, request.dst) + else: + sampling_rate = self.model.config.audio_encoder.sampling_rate + wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) + + print("[transformers] SoundGeneration saved to", request.dst, file=sys.stderr) print(request, file=sys.stderr) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) - - def CallDiaTTS(self, request, context): - """ - Generates dialogue audio using the Dia model. - - Args: - request: A TTSRequest containing text dialogue and generation parameters - context: The gRPC context - - Returns: - A Result object indicating success or failure - """ - try: - print("[DiaTTS] generating dialogue audio", file=sys.stderr) - - # Prepare text input - expect dialogue format like [S1] ... [S2] ... 
- text = [request.text] - - # Process the input - inputs = self.processor(text=text, padding=True, return_tensors="pt") - - # Generate audio with parameters from options or defaults - generation_params = { - **inputs, - "max_new_tokens": self.max_tokens, - "guidance_scale": self.options.get("guidance_scale", 3.0), - "temperature": self.options.get("temperature", 1.8), - "top_p": self.options.get("top_p", 0.90), - "top_k": self.options.get("top_k", 45) - } - - outputs = self.model.generate(**generation_params) - - # Decode and save audio - outputs = self.processor.batch_decode(outputs) - self.processor.save_audio(outputs, request.dst) - - print("[DiaTTS] Generated dialogue audio", file=sys.stderr) - print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr) - print("[DiaTTS] Dialogue generation done", file=sys.stderr) - - except Exception as err: - return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") - return backend_pb2.Result(success=True) - - -# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons def TTS(self, request, context): - if self.DiaTTS: - print("DiaTTS", file=sys.stderr) - return self.CallDiaTTS(request, context) - - model_name = request.model try: - if self.processor is None: - if model_name == "": - return backend_pb2.Result(success=False, message="request.model is required") - self.processor = AutoProcessor.from_pretrained(model_name) - if self.model is None: - if model_name == "": - return backend_pb2.Result(success=False, message="request.model is required") - self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) - inputs = self.processor( - text=[request.text], - padding=True, - return_tensors="pt", - ) - tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default - audio_values = self.model.generate(**inputs, max_new_tokens=tokens) - print("[transformers-musicgen] TTS generated!", file=sys.stderr) - sampling_rate = self.model.config.audio_encoder.sampling_rate - wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) - print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr) - print("[transformers-musicgen] TTS for", file=sys.stderr) - print(request, file=sys.stderr) + text = request.text + print(f"[transformers] TTS generating for text: {text[:100]}...", file=sys.stderr) + + # Build inputs based on processor capabilities + if request.voice and os.path.exists(request.voice): + # Voice cloning: use chat template with reference audio + chat_template = [{ + "role": "0", + "content": [ + {"type": "text", "text": text}, + {"type": "audio", "path": request.voice}, + ], + }] + inputs = self.processor.apply_chat_template( + chat_template, tokenize=True, return_dict=True, + ).to(self.model.device, self.model.dtype) + elif hasattr(self.processor, 'apply_chat_template'): + # Models that use chat template format (VibeVoice, CSM, etc.) + chat_template = [{"role": "0", "content": [{"type": "text", "text": text}]}] + try: + inputs = self.processor.apply_chat_template( + chat_template, tokenize=True, return_dict=True, + ).to(self.model.device, self.model.dtype) + except Exception: + # Fallback if chat template fails (not all processors support it) + inputs = self.processor(text=[text], padding=True, return_tensors="pt") + if self.CUDA: + inputs = inputs.to("cuda") + else: + # Direct processor call (Musicgen, etc.) 
+ inputs = self.processor(text=[text], padding=True, return_tensors="pt") + if self.CUDA: + inputs = inputs.to("cuda") + + # Build generation kwargs from self.options + gen_kwargs = {**inputs, "max_new_tokens": self.max_tokens} + for key in ["guidance_scale", "temperature", "top_p", "top_k", "do_sample"]: + if key in self.options: + gen_kwargs[key] = self.options[key] + + # Add noise scheduler if configured (e.g., for VibeVoice) + noise_scheduler_type = self.options.get("noise_scheduler", None) + if noise_scheduler_type: + import diffusers + SchedulerClass = getattr(diffusers, noise_scheduler_type) + scheduler_kwargs = {} + for key in ["beta_schedule", "prediction_type"]: + if key in self.options: + scheduler_kwargs[key] = self.options[key] + gen_kwargs["noise_scheduler"] = SchedulerClass(**scheduler_kwargs) + + # Generate audio + audio = self.model.generate(**gen_kwargs) + print("[transformers] TTS generated!", file=sys.stderr) + + # Save audio output + if hasattr(self.processor, 'save_audio'): + if hasattr(self.processor, 'batch_decode'): + try: + audio = self.processor.batch_decode(audio) + except Exception: + pass + self.processor.save_audio(audio, request.dst) + else: + sampling_rate = self.model.config.audio_encoder.sampling_rate + wavfile.write(request.dst, rate=sampling_rate, data=audio[0, 0].numpy()) + + print("[transformers] TTS saved to", request.dst, file=sys.stderr) + except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) diff --git a/backend/python/transformers/requirements-cpu.txt b/backend/python/transformers/requirements-cpu.txt index e397245e3..0ba6ba995 100644 --- a/backend/python/transformers/requirements-cpu.txt +++ b/backend/python/transformers/requirements-cpu.txt @@ -2,7 +2,9 @@ torch==2.7.1 llvmlite==0.43.0 numba==0.60.0 accelerate -transformers +transformers>=5.0.0 bitsandbytes sentence-transformers==5.2.3 +diffusers +soundfile protobuf==6.33.5 \ No newline at end of file diff --git a/backend/python/transformers/requirements-cublas12.txt b/backend/python/transformers/requirements-cublas12.txt index 07ae834c7..8c79290ef 100644 --- a/backend/python/transformers/requirements-cublas12.txt +++ b/backend/python/transformers/requirements-cublas12.txt @@ -2,7 +2,9 @@ torch==2.7.1 accelerate llvmlite==0.43.0 numba==0.60.0 -transformers +transformers>=5.0.0 bitsandbytes sentence-transformers==5.2.3 +diffusers +soundfile protobuf==6.33.5 \ No newline at end of file diff --git a/backend/python/transformers/requirements-cublas13.txt b/backend/python/transformers/requirements-cublas13.txt index a82220a08..be1489898 100644 --- a/backend/python/transformers/requirements-cublas13.txt +++ b/backend/python/transformers/requirements-cublas13.txt @@ -2,7 +2,9 @@ torch==2.9.0 llvmlite==0.43.0 numba==0.60.0 -transformers +transformers>=5.0.0 bitsandbytes sentence-transformers==5.2.3 +diffusers +soundfile protobuf==6.33.5 \ No newline at end of file diff --git a/backend/python/transformers/requirements-hipblas.txt b/backend/python/transformers/requirements-hipblas.txt index 2c346dfa6..5460b10b7 100644 --- a/backend/python/transformers/requirements-hipblas.txt +++ b/backend/python/transformers/requirements-hipblas.txt @@ -1,9 +1,11 @@ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 accelerate -transformers +transformers>=5.0.0 llvmlite==0.43.0 numba==0.60.0 bitsandbytes sentence-transformers==5.2.3 +diffusers +soundfile protobuf==6.33.5 \ No newline at end of file 
diff --git a/backend/python/transformers/requirements-intel.txt b/backend/python/transformers/requirements-intel.txt index 61f8d3ba9..a291cb56c 100644 --- a/backend/python/transformers/requirements-intel.txt +++ b/backend/python/transformers/requirements-intel.txt @@ -3,7 +3,9 @@ torch optimum[openvino] llvmlite==0.43.0 numba==0.60.0 -transformers +transformers>=5.0.0 bitsandbytes sentence-transformers==5.2.3 +diffusers +soundfile protobuf==6.33.5 \ No newline at end of file diff --git a/backend/python/transformers/requirements-mps.txt b/backend/python/transformers/requirements-mps.txt index 579fb95c7..cb69b6f42 100644 --- a/backend/python/transformers/requirements-mps.txt +++ b/backend/python/transformers/requirements-mps.txt @@ -2,7 +2,9 @@ torch==2.7.1 llvmlite==0.43.0 numba==0.60.0 accelerate -transformers +transformers>=5.0.0 bitsandbytes sentence-transformers==5.2.3 +diffusers +soundfile protobuf==6.33.5
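
Below is a condensed, standalone sketch of the loading strategy this patch adopts: the requested type string is resolved to a class exported by transformers (with a backward-compat alias for "Mamba"), an AutoProcessor is probed optionally, and generated audio is written via the processor's own writer when available, falling back to raw PCM through scipy. This is an illustration under the patch's assumptions, not code from the repository; the checkpoint id and output path are placeholders.

# Illustrative sketch of the generic loading/saving flow introduced by this patch.
# Not part of the patch itself; the checkpoint id and output path are placeholders.
import sys
import transformers as transformers_module
from transformers import AutoModel, AutoProcessor
from scipy.io import wavfile

TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}  # backward-compat alias, mirroring backend.py

def resolve_model_class(requested_type: str):
    """Map a type string from the request onto a class exported by transformers."""
    name = TYPE_ALIASES.get(requested_type, requested_type)
    if name and hasattr(transformers_module, name):
        print(f"Using model class: {name}", file=sys.stderr)
        return getattr(transformers_module, name)
    print(f"Using default AutoModel (type={requested_type!r})", file=sys.stderr)
    return AutoModel  # fallback when the type is empty or unknown

def load_generic(model_name: str, requested_type: str = ""):
    """Load the model via the resolved class, plus an optional processor."""
    ModelClass = resolve_model_class(requested_type)
    model = ModelClass.from_pretrained(model_name)
    try:
        # Processors only exist for audio/multimodal checkpoints; absence is not an error.
        processor = AutoProcessor.from_pretrained(model_name)
    except Exception:
        processor = None
    return model, processor

def write_audio(model, processor, audio, dst: str):
    """Prefer the processor's own save_audio; otherwise fall back to scipy wavfile."""
    if processor is not None and hasattr(processor, "save_audio"):
        if hasattr(processor, "batch_decode"):
            try:
                audio = processor.batch_decode(audio)
            except Exception:
                pass
        processor.save_audio(audio, dst)
    else:
        rate = model.config.audio_encoder.sampling_rate
        wavfile.write(dst, rate=rate, data=audio[0, 0].numpy())

# Example usage (downloads weights when run):
# model, processor = load_generic("facebook/musicgen-small", "MusicgenForConditionalGeneration")
# audio = model.generate(**processor(text=["lo-fi beat"], padding=True, return_tensors="pt"),
#                        max_new_tokens=256)
# write_audio(model, processor, audio, "out.wav")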