From 2f114b22695c78f6f6b76361f3cf4a2d7b7e623b Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 25 Aug 2025 22:03:39 +0200 Subject: [PATCH] feat(mlx-audio): Add mlx-audio backend Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 13 + Makefile | 4 + backend/index.yaml | 22 + backend/python/mlx-audio/Makefile | 23 + backend/python/mlx-audio/backend.py | 460 ++++++++++++++++++ backend/python/mlx-audio/install.sh | 14 + backend/python/mlx-audio/requirements-mps.txt | 1 + backend/python/mlx-audio/requirements.txt | 7 + backend/python/mlx-audio/run.sh | 11 + backend/python/mlx-audio/test.py | 142 ++++++ backend/python/mlx-audio/test.sh | 12 + 11 files changed, 709 insertions(+) create mode 100644 backend/python/mlx-audio/Makefile create mode 100644 backend/python/mlx-audio/backend.py create mode 100755 backend/python/mlx-audio/install.sh create mode 100644 backend/python/mlx-audio/requirements-mps.txt create mode 100644 backend/python/mlx-audio/requirements.txt create mode 100755 backend/python/mlx-audio/run.sh create mode 100644 backend/python/mlx-audio/test.py create mode 100755 backend/python/mlx-audio/test.sh diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 5a0a6b32e..10e3bda90 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -997,6 +997,19 @@ jobs: dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} + mlx-audio-darwin: + uses: ./.github/workflows/backend_build_darwin.yml + with: + backend: "mlx-audio" + build-type: "mps" + go-version: "1.24.x" + tag-suffix: "-metal-darwin-arm64-mlx-audio" + runs-on: "macOS-14" + secrets: + dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} + dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} + quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} + quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} llama-cpp-darwin: runs-on: macOS-14 strategy: diff --git a/Makefile b/Makefile index 48bf7c9e1..91aba2815 100644 --- a/Makefile +++ b/Makefile @@ -377,6 +377,10 @@ backends/mlx-vlm: BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)" +backends/mlx-audio: + BACKEND=mlx-audio $(MAKE) build-darwin-python-backend + ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)" + backend-images: mkdir -p backend-images diff --git a/backend/index.yaml b/backend/index.yaml index 960cf3aec..cc32ac88e 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -159,6 +159,23 @@ - vision-language - LLM - MLX +- &mlx-audio + name: "mlx-audio" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio" + icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 + urls: + - https://github.com/Blaizzy/mlx-audio + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-mlx-audio + license: MIT + description: | + Run Audio Models with MLX + tags: + - audio-to-text + - audio-generation + - text-to-audio + - LLM + - MLX - &rerankers name: "rerankers" alias: "rerankers" @@ -415,6 +432,11 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx-vlm +- !!merge <<: *mlx-audio + name: "mlx-audio-development" + uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio" + mirrors: + - localai/localai-backends:master-metal-darwin-arm64-mlx-audio - !!merge <<: *kitten-tts name: "kitten-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts" diff --git a/backend/python/mlx-audio/Makefile b/backend/python/mlx-audio/Makefile new file mode 100644 index 000000000..bb7aabe3a --- /dev/null +++ b/backend/python/mlx-audio/Makefile @@ -0,0 +1,23 @@ +.PHONY: mlx-audio +mlx-audio: + bash install.sh + +.PHONY: run +run: mlx-audio + @echo "Running mlx-audio..." + bash run.sh + @echo "mlx run." + +.PHONY: test +test: mlx-audio + @echo "Testing mlx-audio..." + bash test.sh + @echo "mlx tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ \ No newline at end of file diff --git a/backend/python/mlx-audio/backend.py b/backend/python/mlx-audio/backend.py new file mode 100644 index 000000000..0ce302404 --- /dev/null +++ b/backend/python/mlx-audio/backend.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python3 +import asyncio +from concurrent import futures +import argparse +import signal +import sys +import os +import shutil +import glob +from typing import List +import time +import tempfile + +import backend_pb2 +import backend_pb2_grpc + +import grpc +from mlx_audio.tts.generate import generate_audio +import soundfile as sf +import numpy as np + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + """ + A gRPC servicer that implements the Backend service defined in backend.proto. + This backend provides TTS (Text-to-Speech) functionality using MLX-Audio. + """ + + def _is_float(self, s): + """Check if a string can be converted to float.""" + try: + float(s) + return True + except ValueError: + return False + + def _is_int(self, s): + """Check if a string can be converted to int.""" + try: + int(s) + return True + except ValueError: + return False + + def Health(self, request, context): + """ + Returns a health check message. + + Args: + request: The health check request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The health check reply. + """ + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + async def LoadModel(self, request, context): + """ + Loads a TTS model using MLX-Audio. + + Args: + request: The load model request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The load model result. + """ + try: + print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr) + print(f"Request: {request}", file=sys.stderr) + + # Parse options like in the kokoro backend + options = request.Options + self.options = {} + + # The options are a list of strings in this form optname:optvalue + # We store all the options in a dict for later use + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) # Split only on first colon to handle values with colons + + # Convert numeric values to appropriate types + if self._is_float(value): + value = float(value) + elif self._is_int(value): + value = int(value) + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" + + self.options[key] = value + + print(f"Options: {self.options}", file=sys.stderr) + + # Store the model path for later use + self.model_path = request.Model + + except Exception as err: + print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr) + return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}") + + print("MLX-Audio TTS model loaded successfully", file=sys.stderr) + return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True) + + async def Predict(self, request, context): + """ + Generates text based on the given prompt using MLX-Audio TTS. + + Args: + request: The predict request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The predict result. + """ + try: + # For TTS, we expect the prompt to contain the text to synthesize + if not request.Prompt: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("Prompt is required for TTS generation") + return backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + # Get generation parameters + generation_params = self._build_generation_params(request) + + print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {generation_params}", file=sys.stderr) + + # Generate audio using MLX-Audio + # Note: MLX-Audio generates files, so we'll create a temporary file and read it back + with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: + temp_output = tmp_file.name + + try: + # Generate audio using MLX-Audio + generate_audio( + text=request.Prompt, + model_path=self.model_path, + voice=generation_params.get('voice', 'af_heart'), + speed=generation_params.get('speed', 1.0), + lang_code=generation_params.get('lang_code', 'a'), + file_prefix="tts_output", + audio_format="wav", + sample_rate=24000, + join_audio=True, + verbose=False + ) + + # Read the generated audio file + audio_data, sample_rate = sf.read(temp_output) + + # Convert to base64 for response (or handle as needed) + # For now, we'll return a success message + response = f"TTS audio generated successfully. Sample rate: {sample_rate}, Duration: {len(audio_data)/sample_rate:.2f}s" + + return backend_pb2.Reply(message=bytes(response, encoding='utf-8')) + + finally: + # Clean up temporary file + if os.path.exists(temp_output): + os.remove(temp_output) + + except Exception as e: + print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"TTS generation failed: {str(e)}") + return backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + def TTS(self, request, context): + """ + Generates TTS audio from text using MLX-Audio. + + Args: + request: A TTSRequest object containing text, model, destination, voice, and language. + context: A grpc.ServicerContext object that provides information about the RPC. + + Returns: + A Result object indicating success or failure. + """ + try: + print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr) + + # Get generation parameters + generation_params = self._build_generation_params_from_tts(request) + + # Generate audio using MLX-Audio + generate_audio( + text=request.text, + model_path=request.model if request.model else self.model_path, + voice=generation_params.get('voice', 'af_heart'), + speed=generation_params.get('speed', 1.0), + lang_code=generation_params.get('lang_code', 'a'), + file_prefix="tts_output", + audio_format="wav", + sample_rate=24000, + join_audio=True, + verbose=False + ) + + # The generate_audio function creates files with a specific naming pattern + # We need to find the generated file and move it to the requested destination + generated_files = glob.glob("tts_output_*.wav") + + if generated_files: + # Sort by creation time to get the most recent + generated_files.sort(key=lambda x: os.path.getctime(x), reverse=True) + generated_file = generated_files[0] + + # Move to requested destination if specified + if request.dst: + shutil.move(generated_file, request.dst) + print(f"TTS audio saved to: {request.dst}", file=sys.stderr) + else: + print(f"TTS audio generated: {generated_file}", file=sys.stderr) + + # Clean up other generated files + for file in generated_files[1:]: + try: + os.remove(file) + except: + pass + + return backend_pb2.Result(success=True, message=f"TTS audio generated successfully") + else: + return backend_pb2.Result(success=False, message="No audio file was generated") + + except Exception as e: + print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr) + return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}") + + def Embedding(self, request, context): + """ + A gRPC method that calculates embeddings for a given sentence. + + Note: MLX-Audio doesn't support embeddings directly. This method returns an error. + + Args: + request: An EmbeddingRequest object that contains the request parameters. + context: A grpc.ServicerContext object that provides information about the RPC. + + Returns: + An EmbeddingResult object that contains the calculated embeddings. + """ + print("Embeddings not supported in MLX-Audio backend", file=sys.stderr) + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Embeddings are not supported in the MLX-Audio backend.") + return backend_pb2.EmbeddingResult() + + async def PredictStream(self, request, context): + """ + Generates TTS audio based on the given prompt and streams the results using MLX-Audio. + + Args: + request: The predict stream request. + context: The gRPC context. + + Yields: + backend_pb2.Reply: Streaming TTS results. + """ + try: + # For TTS streaming, we expect the prompt to contain the text to synthesize + if not request.Prompt: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("Prompt is required for TTS generation") + yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + return + + # Get generation parameters + generation_params = self._build_generation_params(request) + + print(f"Streaming TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {generation_params}", file=sys.stderr) + + # Generate audio using MLX-Audio + with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: + temp_output = tmp_file.name + + try: + # Generate audio using MLX-Audio + generate_audio( + text=request.Prompt, + model_path=self.model_path, + voice=generation_params.get('voice', 'af_heart'), + speed=generation_params.get('speed', 1.0), + lang_code=generation_params.get('lang_code', 'a'), + file_prefix="tts_stream", + audio_format="wav", + sample_rate=24000, + join_audio=True, + verbose=False + ) + + # Read the generated audio file + audio_data, sample_rate = sf.read(temp_output) + + # For streaming, we'll yield progress updates + # In a real implementation, you might want to stream the audio data itself + yield backend_pb2.Reply(message=bytes(f"TTS generation started. Text length: {len(request.Prompt)}", encoding='utf-8')) + yield backend_pb2.Reply(message=bytes(f"Audio generated. Sample rate: {sample_rate}, Duration: {len(audio_data)/sample_rate:.2f}s", encoding='utf-8')) + + finally: + # Clean up temporary file + if os.path.exists(temp_output): + os.remove(temp_output) + + except Exception as e: + print(f"Error in MLX-Audio TTS PredictStream: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"TTS streaming generation failed: {str(e)}") + yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + def _build_generation_params_from_tts(self, request): + """ + Build generation parameters from TTSRequest for MLX-Audio TTS. + + Args: + request: The TTSRequest object. + + Returns: + dict: Generation parameters for MLX-Audio + """ + # Initialize generation parameters for MLX-Audio TTS + generation_params = { + 'speed': 1.0, + 'voice': 'af_heart', # Default voice + 'lang_code': 'a', # Default language code + } + + # Set voice from request + if request.voice: + generation_params['voice'] = request.voice + + # Set language code from request + if request.language: + generation_params['language'] = request.language + # Map language names to codes if needed + language_map = { + "american_english": "a", + "british_english": "b", + "spanish": "e", + "french": "f", + "hindi": "h", + "italian": "i", + "portuguese": "p", + "japanese": "j", + "mandarin_chinese": "z", + # Also accept direct language codes + "a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z", + } + lang_code = language_map.get(request.language.lower(), request.language) + generation_params['lang_code'] = lang_code + + # Override with options if available + if hasattr(self, 'options'): + # Speed from options + if 'speed' in self.options: + generation_params['speed'] = self.options['speed'] + + # Voice from options + if 'voice' in self.options: + generation_params['voice'] = self.options['voice'] + + # Language code from options + if 'lang_code' in self.options: + generation_params['lang_code'] = self.options['lang_code'] + + return generation_params + + def _build_generation_params(self, request, default_speed=1.0): + """ + Build generation parameters from request attributes and options for MLX-Audio TTS. + + Args: + request: The gRPC request. + default_speed: Default speed if not specified. + + Returns: + dict: Generation parameters for MLX-Audio + """ + # Initialize generation parameters for MLX-Audio TTS + generation_params = { + 'speed': default_speed, + 'voice': 'af_heart', # Default voice + 'lang_code': 'a', # Default language code + } + + # Extract parameters from request attributes + if hasattr(request, 'Temperature') and request.Temperature > 0: + # Temperature could be mapped to speed variation + generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5 + + # Override with options if available + if hasattr(self, 'options'): + # Speed from options + if 'speed' in self.options: + generation_params['speed'] = self.options['speed'] + + # Voice from options + if 'voice' in self.options: + generation_params['voice'] = self.options['voice'] + + # Language code from options + if 'lang_code' in self.options: + generation_params['lang_code'] = self.options['lang_code'] + + # Model-specific parameters + param_option_mapping = { + 'temp': 'speed', + 'temperature': 'speed', + 'top_p': 'speed', # Map top_p to speed variation + } + + for option_key, param_key in param_option_mapping.items(): + if option_key in self.options: + if param_key == 'speed': + # Ensure speed is within reasonable bounds + speed_val = float(self.options[option_key]) + if 0.5 <= speed_val <= 2.0: + generation_params[param_key] = speed_val + + return generation_params + +async def serve(address): + # Start asyncio gRPC server + server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB + ]) + # Add the servicer to the server + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + # Bind the server to the address + server.add_insecure_port(address) + + # Gracefully shutdown the server on SIGTERM or SIGINT + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler( + sig, lambda: asyncio.ensure_future(server.stop(5)) + ) + + # Start the server + await server.start() + print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr) + # Wait for the server to be terminated + await server.wait_for_termination() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." + ) + args = parser.parse_args() + + asyncio.run(serve(args.addr)) diff --git a/backend/python/mlx-audio/install.sh b/backend/python/mlx-audio/install.sh new file mode 100755 index 000000000..b8ee48552 --- /dev/null +++ b/backend/python/mlx-audio/install.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +USE_PIP=true + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +installRequirements diff --git a/backend/python/mlx-audio/requirements-mps.txt b/backend/python/mlx-audio/requirements-mps.txt new file mode 100644 index 000000000..31df2a190 --- /dev/null +++ b/backend/python/mlx-audio/requirements-mps.txt @@ -0,0 +1 @@ +git+https://github.com/Blaizzy/mlx-audio \ No newline at end of file diff --git a/backend/python/mlx-audio/requirements.txt b/backend/python/mlx-audio/requirements.txt new file mode 100644 index 000000000..5f47f0cfd --- /dev/null +++ b/backend/python/mlx-audio/requirements.txt @@ -0,0 +1,7 @@ +grpcio==1.71.0 +protobuf +certifi +setuptools +mlx-audio +soundfile +numpy \ No newline at end of file diff --git a/backend/python/mlx-audio/run.sh b/backend/python/mlx-audio/run.sh new file mode 100755 index 000000000..fc88f97da --- /dev/null +++ b/backend/python/mlx-audio/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ \ No newline at end of file diff --git a/backend/python/mlx-audio/test.py b/backend/python/mlx-audio/test.py new file mode 100644 index 000000000..792cb0648 --- /dev/null +++ b/backend/python/mlx-audio/test.py @@ -0,0 +1,142 @@ +import unittest +import subprocess +import time +import backend_pb2 +import backend_pb2_grpc + +import grpc + +import unittest +import subprocess +import time +import grpc +import backend_pb2_grpc +import backend_pb2 + +class TestBackendServicer(unittest.TestCase): + """ + TestBackendServicer is the class that tests the gRPC service. + + This class contains methods to test the startup and shutdown of the gRPC service. + """ + def setUp(self): + self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"]) + time.sleep(10) + + def tearDown(self) -> None: + self.service.terminate() + self.service.wait() + + def test_server_startup(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + self.fail("Server failed to start") + finally: + self.tearDown() + def test_load_model(self): + """ + This method tests if the TTS model is loaded successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit")) + self.assertTrue(response.success) + self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully") + except Exception as err: + print(err) + self.fail("LoadModel service failed") + finally: + self.tearDown() + + def test_tts_generation(self): + """ + This method tests if TTS audio is generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit")) + self.assertTrue(response.success) + + # Test TTS generation + tts_req = backend_pb2.TTSRequest( + text="Hello, this is a test of the MLX-Audio TTS system.", + model="mlx-community/Kokoro-82M-4bit", + voice="af_heart", + language="a" + ) + tts_resp = stub.TTS(tts_req) + self.assertTrue(tts_resp.success) + self.assertIn("TTS audio generated successfully", tts_resp.message) + except Exception as err: + print(err) + self.fail("TTS service failed") + finally: + self.tearDown() + + def test_tts_with_options(self): + """ + This method tests if TTS works with various options and parameters + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions( + Model="mlx-community/Kokoro-82M-4bit", + Options=["voice:af_soft", "speed:1.2", "lang_code:b"] + )) + self.assertTrue(response.success) + + # Test TTS generation with different voice and language + tts_req = backend_pb2.TTSRequest( + text="Hello, this is a test with British English accent.", + model="mlx-community/Kokoro-82M-4bit", + voice="af_soft", + language="b" + ) + tts_resp = stub.TTS(tts_req) + self.assertTrue(tts_resp.success) + self.assertIn("TTS audio generated successfully", tts_resp.message) + except Exception as err: + print(err) + self.fail("TTS with options service failed") + finally: + self.tearDown() + + + def test_tts_multilingual(self): + """ + This method tests if TTS works with different languages + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit")) + self.assertTrue(response.success) + + # Test Spanish TTS + tts_req = backend_pb2.TTSRequest( + text="Hola, esto es una prueba del sistema TTS MLX-Audio.", + model="mlx-community/Kokoro-82M-4bit", + voice="af_heart", + language="e" + ) + tts_resp = stub.TTS(tts_req) + self.assertTrue(tts_resp.success) + self.assertIn("TTS audio generated successfully", tts_resp.message) + except Exception as err: + print(err) + self.fail("Multilingual TTS service failed") + finally: + self.tearDown() \ No newline at end of file diff --git a/backend/python/mlx-audio/test.sh b/backend/python/mlx-audio/test.sh new file mode 100755 index 000000000..f31ae54e4 --- /dev/null +++ b/backend/python/mlx-audio/test.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests