diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index cb2e073ac..f26a94b57 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -41,6 +41,14 @@ from optimum.quanto import freeze, qfloat8, quantize
 from transformers import T5EncoderModel
 from safetensors.torch import load_file
 
+# Import LTX-2 specific utilities
+try:
+    from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
+    LTX2_AVAILABLE = True
+except ImportError:
+    LTX2_AVAILABLE = False
+    ltx2_encode_video = None
+
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 COMPEL = os.environ.get("COMPEL", "0") == "1"
 XPU = os.environ.get("XPU", "0") == "1"
@@ -290,6 +298,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 pipe.enable_model_cpu_offload()
             return pipe
 
+        # LTX2ImageToVideoPipeline - needs img2vid flag, CPU offload, and special handling
+        if pipeline_type == "LTX2ImageToVideoPipeline":
+            self.img2vid = True
+            self.ltx2_pipeline = True
+            pipe = load_diffusers_pipeline(
+                class_name="LTX2ImageToVideoPipeline",
+                model_id=request.Model,
+                torch_dtype=torchType,
+                variant=variant
+            )
+            if not DISABLE_CPU_OFFLOAD:
+                pipe.enable_model_cpu_offload()
+            return pipe
+
         # ================================================================
         # Dynamic pipeline loading - the default path for most pipelines
         # Uses the dynamic loader to instantiate any pipeline by class name
@@ -404,6 +426,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
         self.img2vid = False
         self.txt2vid = False
+        self.ltx2_pipeline = False
 
         # Load pipeline using dynamic loader
         # Special cases that require custom initialization are handled first
@@ -686,7 +709,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         print(f"Generating video with {kwargs=}", file=sys.stderr)
 
         # Generate video frames based on pipeline type
-        if self.PipelineType == "WanPipeline":
+        if self.ltx2_pipeline or self.PipelineType == "LTX2ImageToVideoPipeline":
+            # LTX-2 image-to-video generation with audio
+            if not LTX2_AVAILABLE:
+                return backend_pb2.Result(success=False, message="LTX-2 pipeline requires diffusers.pipelines.ltx2.export_utils")
+
+            # LTX-2 uses 'image' parameter instead of 'start_image'
+            if request.start_image:
+                image = load_image(request.start_image)
+                kwargs["image"] = image
+            # Remove start_image if it was added
+            kwargs.pop("start_image", None)
+
+            # LTX-2 uses 'frame_rate' instead of 'fps'
+            frame_rate = float(fps)
+            kwargs["frame_rate"] = frame_rate
+
+            # LTX-2 requires output_type="np" and return_dict=False
+            kwargs["output_type"] = "np"
+            kwargs["return_dict"] = False
+
+            # Generate video and audio
+            video, audio = self.pipe(**kwargs)
+
+            # Convert video to uint8 format
+            video = (video * 255).round().astype("uint8")
+            video = torch.from_numpy(video)
+
+            # Use LTX-2's encode_video function which handles audio
+            ltx2_encode_video(
+                video[0],
+                fps=frame_rate,
+                audio=audio[0].float().cpu(),
+                audio_sample_rate=self.pipe.vocoder.config.output_sampling_rate,
+                output_path=request.dst,
+            )
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+        elif self.PipelineType == "WanPipeline":
             # WAN2.2 text-to-video generation
             output = self.pipe(**kwargs)
             frames = output.frames[0]  # WAN2.2 returns frames in this format
@@ -727,7 +787,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         else:
             return backend_pb2.Result(success=False,
message=f"Pipeline {self.PipelineType} does not support video generation") - # Export video + # Export video (for non-LTX-2 pipelines) export_to_video(frames, request.dst, fps=fps) return backend_pb2.Result(message="Video generated successfully", success=True) diff --git a/gallery/index.yaml b/gallery/index.yaml index da8c0cd0f..842633e7f 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -1247,6 +1247,63 @@ cuda: true pipeline_type: QwenImageEditPipeline enable_parameters: num_inference_steps,image +- <x2 + name: "ltx-2" + url: "github:mudler/LocalAI/gallery/virtual.yaml@master" + urls: + - https://huggingface.co/Lightricks/LTX-2 + license: ltx-2-community-license-agreement + tags: + - diffusers + - gpu + - image-to-video + - video-generation + - audio-video + description: | + **LTX-2** is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution. + + **Key Features:** + - **Joint Audio-Video Generation**: Generates synchronized video and audio in a single model + - **Image-to-Video**: Converts static images into dynamic videos with matching audio + - **High Quality**: Produces realistic video with natural motion and synchronized audio + - **Open Weights**: Available under the LTX-2 Community License Agreement + + **Model Details:** + - **Model Type**: Diffusion-based audio-video foundation model + - **Architecture**: DiT (Diffusion Transformer) based + - **Developed by**: Lightricks + - **Paper**: [LTX-2: Efficient Joint Audio-Visual Foundation Model](https://arxiv.org/abs/2601.03233) + + **Usage Tips:** + - Width & height settings must be divisible by 32 + - Frame count must be divisible by 8 + 1 (e.g., 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105, 113, 121) + - Recommended settings: width=768, height=512, num_frames=121, frame_rate=24.0 + - For best results, use detailed prompts describing motion and scene dynamics + + **Limitations:** + - This model is not intended or able to provide factual information + - Prompt following is heavily influenced by the prompting-style + - When generating audio without speech, the audio may be of lower quality + + **Citation:** + ```bibtex + @article{hacohen2025ltx2, + title={LTX-2: Efficient Joint Audio-Visual Foundation Model}, + author={HaCohen, Yoav and Brazowski, Benny and Chiprut, Nisan and others}, + journal={arXiv preprint arXiv:2601.03233}, + year={2025} + } + ``` + overrides: + backend: diffusers + low_vram: true + parameters: + model: Lightricks/LTX-2 + diffusers: + cuda: true + pipeline_type: LTX2ImageToVideoPipeline + options: + - torch_dtype:bf16 - &gptoss name: "gpt-oss-20b" url: "github:mudler/LocalAI/gallery/harmony.yaml@master"