diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index f26a94b57..032af60c4 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -42,12 +42,8 @@ from transformers import T5EncoderModel
 from safetensors.torch import load_file
 
 # Import LTX-2 specific utilities
-try:
-    from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
-    LTX2_AVAILABLE = True
-except ImportError:
-    LTX2_AVAILABLE = False
-    ltx2_encode_video = None
+from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
+from diffusers import LTX2VideoTransformer3DModel, GGUFQuantizationConfig
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 COMPEL = os.environ.get("COMPEL", "0") == "1"
@@ -302,12 +298,96 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if pipeline_type == "LTX2ImageToVideoPipeline":
             self.img2vid = True
             self.ltx2_pipeline = True
-            pipe = load_diffusers_pipeline(
-                class_name="LTX2ImageToVideoPipeline",
-                model_id=request.Model,
-                torch_dtype=torchType,
-                variant=variant
-            )
+
+            # Check if loading from single file (GGUF)
+            if fromSingleFile and LTX2VideoTransformer3DModel is not None:
+                _, single_file_ext = os.path.splitext(modelFile)
+                if single_file_ext == ".gguf":
+                    # Load transformer from single GGUF file with quantization
+                    transformer_kwargs = {}
+                    quantization_config = GGUFQuantizationConfig(compute_dtype=torchType)
+                    transformer_kwargs["quantization_config"] = quantization_config
+
+                    transformer = LTX2VideoTransformer3DModel.from_single_file(
+                        modelFile,
+                        config=request.Model,  # Use request.Model as the config/model_id
+                        subfolder="transformer",
+                        **transformer_kwargs,
+                    )
+
+                    # Load pipeline with custom transformer
+                    pipe = load_diffusers_pipeline(
+                        class_name="LTX2ImageToVideoPipeline",
+                        model_id=request.Model,
+                        transformer=transformer,
+                        torch_dtype=torchType,
+                    )
+                else:
+                    # Single file but not GGUF - use standard single file loading
+                    pipe = load_diffusers_pipeline(
+                        class_name="LTX2ImageToVideoPipeline",
+                        model_id=modelFile,
+                        from_single_file=True,
+                        torch_dtype=torchType,
+                    )
+            else:
+                # Standard loading from pretrained
+                pipe = load_diffusers_pipeline(
+                    class_name="LTX2ImageToVideoPipeline",
+                    model_id=request.Model,
+                    torch_dtype=torchType,
+                    variant=variant
+                )
+
+            if not DISABLE_CPU_OFFLOAD:
+                pipe.enable_model_cpu_offload()
+            return pipe
+
+        # LTX2Pipeline - text-to-video pipeline, needs txt2vid flag, CPU offload, and special handling
+        if pipeline_type == "LTX2Pipeline":
+            self.txt2vid = True
+            self.ltx2_pipeline = True
+
+            # Check if loading from single file (GGUF)
+            if fromSingleFile and LTX2VideoTransformer3DModel is not None:
+                _, single_file_ext = os.path.splitext(modelFile)
+                if single_file_ext == ".gguf":
+                    # Load transformer from single GGUF file with quantization
+                    transformer_kwargs = {}
+                    quantization_config = GGUFQuantizationConfig(compute_dtype=torchType)
+                    transformer_kwargs["quantization_config"] = quantization_config
+
+                    transformer = LTX2VideoTransformer3DModel.from_single_file(
+                        modelFile,
+                        config=request.Model,  # Use request.Model as the config/model_id
+                        subfolder="transformer",
+                        **transformer_kwargs,
+                    )
+
+                    # Load pipeline with custom transformer
+                    pipe = load_diffusers_pipeline(
+                        class_name="LTX2Pipeline",
+                        model_id=request.Model,
+                        transformer=transformer,
+                        torch_dtype=torchType,
+                    )
+                else:
+                    # Single file but not GGUF - use standard single file loading
+                    pipe = load_diffusers_pipeline(
+                        class_name="LTX2Pipeline",
+                        model_id=modelFile,
+                        from_single_file=True,
+                        torch_dtype=torchType,
+                    )
+            else:
+                # Standard loading from pretrained
+                pipe = load_diffusers_pipeline(
+                    class_name="LTX2Pipeline",
+                    model_id=request.Model,
+                    torch_dtype=torchType,
+                    variant=variant
+                )
+
             if not DISABLE_CPU_OFFLOAD:
                 pipe.enable_model_cpu_offload()
             return pipe
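
A note on the hunk above: the GGUF branch mirrors the single-file loading pattern diffusers exposes through GGUFQuantizationConfig. A minimal standalone sketch of the same flow, assuming a diffusers build that ships the LTX-2 classes this patch imports; the checkpoint path and repo id below are hypothetical placeholders, not values from the patch:

    # Standalone sketch of the GGUF single-file path added above; assumes
    # diffusers provides LTX2Pipeline / LTX2VideoTransformer3DModel, as the
    # imports in this patch do. Paths and ids are hypothetical.
    import torch
    from diffusers import GGUFQuantizationConfig, LTX2Pipeline, LTX2VideoTransformer3DModel

    transformer = LTX2VideoTransformer3DModel.from_single_file(
        "ltx2-transformer-Q4_K_M.gguf",        # hypothetical local GGUF checkpoint
        config="Lightricks/LTX-2",             # hypothetical repo id used as config source
        subfolder="transformer",
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    )
    pipe = LTX2Pipeline.from_pretrained(
        "Lightricks/LTX-2",                    # hypothetical repo id
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipe.enable_model_cpu_offload()            # same VRAM-saving step the backend applies
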
@@ -428,6 +508,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         self.txt2vid = False
         self.ltx2_pipeline = False
 
+        print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)
+
         # Load pipeline using dynamic loader
         # Special cases that require custom initialization are handled first
         self.pipe = self._load_pipeline(
@@ -437,6 +519,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             torchType=torchType,
             variant=variant
         )
+
+        print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)
 
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
@@ -674,14 +758,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         try:
             prompt = request.prompt
             if not prompt:
+                print("GenerateVideo: No prompt provided for video generation", file=sys.stderr)
                 return backend_pb2.Result(success=False, message="No prompt provided for video generation")
 
+            # Debug: Print raw request values
+            print(f"GenerateVideo: Raw request values - num_frames: {request.num_frames}, fps: {request.fps}, cfg_scale: {request.cfg_scale}, step: {request.step}", file=sys.stderr)
+
             # Set default values from request or use defaults
             num_frames = request.num_frames if request.num_frames > 0 else 81
             fps = request.fps if request.fps > 0 else 16
             cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
             num_inference_steps = request.step if request.step > 0 else 40
 
+            print(f"GenerateVideo: Using values - num_frames: {num_frames}, fps: {fps}, cfg_scale: {cfg_scale}, num_inference_steps: {num_inference_steps}", file=sys.stderr)
+
             # Prepare generation parameters
             kwargs = {
                 "prompt": prompt,
@@ -707,19 +797,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 kwargs["end_image"] = load_image(request.end_image)
 
             print(f"Generating video with {kwargs=}", file=sys.stderr)
+            print(f"GenerateVideo: Pipeline type: {self.PipelineType}, ltx2_pipeline flag: {self.ltx2_pipeline}", file=sys.stderr)
 
             # Generate video frames based on pipeline type
-            if self.ltx2_pipeline or self.PipelineType == "LTX2ImageToVideoPipeline":
-                # LTX-2 image-to-video generation with audio
-                if not LTX2_AVAILABLE:
-                    return backend_pb2.Result(success=False, message="LTX-2 pipeline requires diffusers.pipelines.ltx2.export_utils")
-
-                # LTX-2 uses 'image' parameter instead of 'start_image'
-                if request.start_image:
-                    image = load_image(request.start_image)
-                    kwargs["image"] = image
-                    # Remove start_image if it was added
-                    kwargs.pop("start_image", None)
+            if self.ltx2_pipeline or self.PipelineType in ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]:
+                # LTX-2 generation with audio (supports both text-to-video and image-to-video)
+                # Determine if this is text-to-video (no image) or image-to-video (has image)
+                has_image = bool(request.start_image)
+
+                # Remove image-related parameters that might have been added earlier
+                kwargs.pop("start_image", None)
+                kwargs.pop("end_image", None)
+
+                # LTX2ImageToVideoPipeline uses the 'image' parameter for image-to-video;
+                # LTX2Pipeline (text-to-video) doesn't need an image parameter
+                if has_image:
+                    # Image-to-video: use the 'image' parameter
+                    if self.PipelineType == "LTX2ImageToVideoPipeline":
+                        image = load_image(request.start_image)
+                        kwargs["image"] = image
+                        print("LTX-2: Using image-to-video mode with image", file=sys.stderr)
+                    else:
+                        # If the pipeline type is LTX2Pipeline but we have an image, we can't do image-to-video
+                        return backend_pb2.Result(success=False, message="LTX2Pipeline does not support image-to-video. Use LTX2ImageToVideoPipeline for image-to-video generation.")
+                else:
+                    # Text-to-video: no image parameter needed
+                    # Ensure no image-related kwargs are present
+                    kwargs.pop("image", None)
+                    print("LTX-2: Using text-to-video mode (no image)", file=sys.stderr)
 
                 # LTX-2 uses 'frame_rate' instead of 'fps'
                 frame_rate = float(fps)
@@ -730,20 +835,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 kwargs["return_dict"] = False
 
                 # Generate video and audio
-                video, audio = self.pipe(**kwargs)
+                print(f"LTX-2: Generating with kwargs: {kwargs}", file=sys.stderr)
+                try:
+                    video, audio = self.pipe(**kwargs)
+                    print(f"LTX-2: Generated video shape: {video.shape}, audio shape: {audio.shape}", file=sys.stderr)
+                except Exception as e:
+                    print(f"LTX-2: Error during pipe() call: {e}", file=sys.stderr)
+                    traceback.print_exc()
+                    return backend_pb2.Result(success=False, message=f"Error generating video with LTX-2 pipeline: {e}")
 
                 # Convert video to uint8 format
                 video = (video * 255).round().astype("uint8")
                 video = torch.from_numpy(video)
 
+                print(f"LTX-2: Converting video, shape after conversion: {video.shape}", file=sys.stderr)
+                print(f"LTX-2: Audio sample rate: {self.pipe.vocoder.config.output_sampling_rate}", file=sys.stderr)
+                print(f"LTX-2: Output path: {request.dst}", file=sys.stderr)
+
                 # Use LTX-2's encode_video function which handles audio
-                ltx2_encode_video(
-                    video[0],
-                    fps=frame_rate,
-                    audio=audio[0].float().cpu(),
-                    audio_sample_rate=self.pipe.vocoder.config.output_sampling_rate,
-                    output_path=request.dst,
-                )
+                try:
+                    ltx2_encode_video(
+                        video[0],
+                        fps=frame_rate,
+                        audio=audio[0].float().cpu(),
+                        audio_sample_rate=self.pipe.vocoder.config.output_sampling_rate,
+                        output_path=request.dst,
+                    )
+                    # Verify file was created and has content
+                    import os
+                    if os.path.exists(request.dst):
+                        file_size = os.path.getsize(request.dst)
+                        print(f"LTX-2: Video file created successfully, size: {file_size} bytes", file=sys.stderr)
+                        if file_size == 0:
+                            return backend_pb2.Result(success=False, message="Video file was created but is empty (0 bytes). Check LTX-2 encode_video function.")
+                    else:
+                        return backend_pb2.Result(success=False, message=f"Video file was not created at {request.dst}")
+                except Exception as e:
+                    print(f"LTX-2: Error encoding video: {e}", file=sys.stderr)
+                    traceback.print_exc()
+                    return backend_pb2.Result(success=False, message=f"Error encoding video: {e}")
 
                 return backend_pb2.Result(message="Video generated successfully", success=True)
             elif self.PipelineType == "WanPipeline":
@@ -785,11 +915,23 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 output = self.pipe(**kwargs)
                 frames = output.frames[0]
             else:
+                print(f"GenerateVideo: Pipeline {self.PipelineType} does not match any known video pipeline handler", file=sys.stderr)
                 return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
 
             # Export video (for non-LTX-2 pipelines)
+            print(f"GenerateVideo: Exporting video to {request.dst} with fps={fps}", file=sys.stderr)
             export_to_video(frames, request.dst, fps=fps)
 
+            # Verify file was created
+            import os
+            if os.path.exists(request.dst):
+                file_size = os.path.getsize(request.dst)
+                print(f"GenerateVideo: Video file created, size: {file_size} bytes", file=sys.stderr)
+                if file_size == 0:
+                    return backend_pb2.Result(success=False, message="Video file was created but is empty (0 bytes)")
+            else:
+                return backend_pb2.Result(success=False, message=f"Video file was not created at {request.dst}")
+
             return backend_pb2.Result(message="Video generated successfully", success=True)
 
         except Exception as err:
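
The GenerateVideo changes above lean on the LTX-2 call contract where return_dict=False yields a (video, audio) pair, with video as a float array in [0, 1]. A condensed sketch of that post-processing path, assuming a loaded LTX-2 pipeline as built in the earlier hunks; the prompt and output path are hypothetical:

    # Sketch of the LTX-2 post-processing path used in GenerateVideo above;
    # `pipe` is assumed to be a loaded LTX-2 pipeline, the rest is hypothetical.
    import torch
    from diffusers.pipelines.ltx2.export_utils import encode_video

    video, audio = pipe(
        prompt="a river at dawn",      # hypothetical prompt
        frame_rate=16.0,               # LTX-2 takes frame_rate, not fps
        return_dict=False,             # yields a (video, audio) tuple
    )

    # video is float in [0, 1], batch-first; scale to uint8 frames
    frames = torch.from_numpy((video * 255).round().astype("uint8"))

    encode_video(
        frames[0],                     # first batch element
        fps=16.0,
        audio=audio[0].float().cpu(),  # matching waveform
        audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
        output_path="out.mp4",         # hypothetical destination
    )
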
diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go
index 4ff343af0..da33a0373 100644
--- a/core/http/endpoints/localai/video.go
+++ b/core/http/endpoints/localai/video.go
@@ -167,6 +167,16 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
 
 	baseURL := middleware.BaseURL(c)
 
+	xlog.Debug("VideoEndpoint: Calling VideoGeneration",
+		"num_frames", input.NumFrames,
+		"fps", input.FPS,
+		"cfg_scale", input.CFGScale,
+		"step", input.Step,
+		"seed", input.Seed,
+		"width", width,
+		"height", height,
+		"negative_prompt", input.NegativePrompt)
+
 	fn, err := backend.VideoGeneration(
 		height,
 		width,
diff --git a/core/http/endpoints/openai/video.go b/core/http/endpoints/openai/video.go
deleted file mode 100644
index 12c06ffe6..000000000
--- a/core/http/endpoints/openai/video.go
+++ /dev/null
@@ -1,140 +0,0 @@
-package openai
-
-import (
-	"encoding/json"
-	"fmt"
-	"strconv"
-	"strings"
-
-	"github.com/labstack/echo/v4"
-	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/http/endpoints/localai"
-	"github.com/mudler/LocalAI/core/http/middleware"
-	"github.com/mudler/LocalAI/core/schema"
-	model "github.com/mudler/LocalAI/pkg/model"
-)
-
-func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
-	return func(c echo.Context) error {
-		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
-		if !ok || input == nil {
-			return echo.ErrBadRequest
-		}
-		var raw map[string]interface{}
-		body := make([]byte, 0)
-		if c.Request().Body != nil {
-			c.Request().Body.Read(body)
-		}
-		if len(body) > 0 {
-			_ = json.Unmarshal(body, &raw)
-		}
-		// Build VideoRequest using shared mapper
-		vr := MapOpenAIToVideo(input, raw)
-		// Place VideoRequest into context so localai.VideoEndpoint can consume it
-		c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
-		// Delegate to existing localai handler
-		return localai.VideoEndpoint(cl, ml, appConfig)(c)
-	}
-}
-
-// VideoEndpoint godoc
-// @Summary Generate a video from an OpenAI-compatible request
-// @Description Accepts an OpenAI-style request and delegates to the LocalAI video generator
-// @Tags openai
-// @Accept json
-// @Produce json
-// @Param request body schema.OpenAIRequest true "OpenAI-style request"
-// @Success 200 {object} map[string]interface{}
-// @Failure 400 {object} map[string]interface{}
-// @Router /v1/videos [post]
-
-func MapOpenAIToVideo(input *schema.OpenAIRequest, raw map[string]interface{}) *schema.VideoRequest {
-	vr := &schema.VideoRequest{}
-	if input == nil {
-		return vr
-	}
-
-	if input.Model != "" {
-		vr.Model = input.Model
-	}
-
-	// Prompt mapping
-	switch p := input.Prompt.(type) {
-	case string:
-		vr.Prompt = p
-	case []interface{}:
-		if len(p) > 0 {
-			if s, ok := p[0].(string); ok {
-				vr.Prompt = s
-			}
-		}
-	}
-
-	// Size
-	size := input.Size
-	if size == "" && raw != nil {
-		if v, ok := raw["size"].(string); ok {
-			size = v
-		}
-	}
-	if size != "" {
-		parts := strings.SplitN(size, "x", 2)
-		if len(parts) == 2 {
-			if wi, err := strconv.Atoi(parts[0]); err == nil {
-				vr.Width = int32(wi)
-			}
-			if hi, err := strconv.Atoi(parts[1]); err == nil {
-				vr.Height = int32(hi)
-			}
-		}
-	}
-
-	// seconds -> num frames
-	secondsStr := ""
-	if raw != nil {
-		if v, ok := raw["seconds"].(string); ok {
-			secondsStr = v
-		} else if v, ok := raw["seconds"].(float64); ok {
-			secondsStr = fmt.Sprintf("%v", int(v))
-		}
-	}
-	fps := int32(30)
-	if raw != nil {
-		if rawFPS, ok := raw["fps"]; ok {
-			switch rf := rawFPS.(type) {
-			case float64:
-				fps = int32(rf)
-			case string:
-				if fi, err := strconv.Atoi(rf); err == nil {
-					fps = int32(fi)
-				}
-			}
-		}
-	}
-	if secondsStr != "" {
-		if secF, err := strconv.Atoi(secondsStr); err == nil {
-			vr.FPS = fps
-			vr.NumFrames = int32(secF) * fps
-		}
-	}
-
-	// input_reference
-	if raw != nil {
-		if v, ok := raw["input_reference"].(string); ok {
-			vr.StartImage = v
-		}
-	}
-
-	// response format
-	if input.ResponseFormat != nil {
-		if rf, ok := input.ResponseFormat.(string); ok {
-			vr.ResponseFormat = rf
-		}
-	}
-
-	if input.Step != 0 {
-		vr.Step = int32(input.Step)
-	}
-
-	return vr
-}
diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go
index 2d62859f3..59514339e 100644
--- a/core/http/routes/openai.go
+++ b/core/http/routes/openai.go
@@ -152,27 +152,6 @@ func RegisterOpenAIRoutes(app *echo.Echo,
 	app.POST("/v1/images/inpainting", inpaintingHandler, imageMiddleware...)
 	app.POST("/images/inpainting", inpaintingHandler, imageMiddleware...)
 
-	// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
-	videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
-	videoMiddleware := []echo.MiddlewareFunc{
-		traceMiddleware,
-		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
-		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
-		func(next echo.HandlerFunc) echo.HandlerFunc {
-			return func(c echo.Context) error {
-				if err := re.SetOpenAIRequest(c); err != nil {
-					return err
-				}
-				return next(c)
-			}
-		},
-	}
-
-	// OpenAI-style create video endpoint
-	app.POST("/v1/videos", videoHandler, videoMiddleware...)
-	app.POST("/v1/videos/generations", videoHandler, videoMiddleware...)
-	app.POST("/videos", videoHandler, videoMiddleware...)
-
 	// List models
 	app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
 	app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
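
With the OpenAI-compatible /v1/videos routes removed, clients follow the web UI and call the LocalAI-native video endpoint directly. A hedged sketch of such a call; the route, port, and JSON field names are assumptions inferred from the handler and debug log above rather than confirmed schema:

    # Sketch of a client for the LocalAI-native video endpoint; the URL and
    # field names are assumptions, and the defaults mirror backend.py
    # (81 frames, 16 fps, cfg_scale 4.0, 40 steps).
    import requests

    resp = requests.post(
        "http://localhost:8080/video",       # assumed LocalAI address and route
        json={
            "model": "ltx-2",                # hypothetical model name
            "prompt": "a boat drifting at sunset",
            "num_frames": 81,
            "fps": 16,
            "cfg_scale": 4.0,
            "step": 40,
        },
        timeout=600,
    )
    resp.raise_for_status()
    print(resp.json())
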
diff --git a/core/http/static/video.js b/core/http/static/video.js
index bcc2f782a..3d229e0d8 100644
--- a/core/http/static/video.js
+++ b/core/http/static/video.js
@@ -135,9 +135,9 @@ async function promptVideo() {
         return;
     }
 
-    // Make API request
+    // Make API request to LocalAI endpoint
     try {
-        const response = await fetch("v1/videos/generations", {
+        const response = await fetch("video", {
             method: "POST",
             headers: {
                 "Content-Type": "application/json",
@@ -219,9 +219,13 @@ async function promptVideo() {
         `;
         captionDiv.appendChild(detailsDiv);
 
+        // Button container
+        const buttonContainer = document.createElement("div");
+        buttonContainer.className = "mt-1.5 flex gap-2";
+
         // Copy prompt button
         const copyBtn = document.createElement("button");
-        copyBtn.className = "mt-1.5 px-2 py-0.5 text-[10px] bg-[var(--color-primary)] text-white rounded hover:opacity-80";
+        copyBtn.className = "px-2 py-0.5 text-[10px] bg-[var(--color-primary)] text-white rounded hover:opacity-80";
         copyBtn.innerHTML = 'Copy Prompt';
         copyBtn.onclick = () => {
             navigator.clipboard.writeText(prompt).then(() => {
@@ -231,7 +235,18 @@ async function promptVideo() {
                 }, 2000);
             });
         };
-        captionDiv.appendChild(copyBtn);
+        buttonContainer.appendChild(copyBtn);
+
+        // Download video button
+        const downloadBtn = document.createElement("button");
+        downloadBtn.className = "px-2 py-0.5 text-[10px] bg-[var(--color-primary)] text-white rounded hover:opacity-80";
+        downloadBtn.innerHTML = 'Download Video';
+        downloadBtn.onclick = () => {
+            downloadVideo(item, downloadBtn);
+        };
+        buttonContainer.appendChild(downloadBtn);
+
+        captionDiv.appendChild(buttonContainer);
 
         videoContainer.appendChild(captionDiv);
         resultDiv.appendChild(videoContainer);
@@ -269,6 +284,67 @@ function escapeHtml(text) {
     return div.innerHTML;
 }
 
+// Helper function to download video
+function downloadVideo(item, button) {
+    try {
+        let videoUrl;
+        let filename = "generated-video.mp4";
+
+        if (item.url) {
+            // If we have a URL, use it directly
+            videoUrl = item.url;
+            // Extract filename from URL if possible
+            const urlParts = item.url.split("/");
+            if (urlParts.length > 0) {
+                const lastPart = urlParts[urlParts.length - 1];
+                if (lastPart && lastPart.includes(".")) {
+                    filename = lastPart;
+                }
+            }
+        } else if (item.b64_json) {
+            // Convert base64 to blob
+            const byteCharacters = atob(item.b64_json);
+            const byteNumbers = new Array(byteCharacters.length);
+            for (let i = 0; i < byteCharacters.length; i++) {
+                byteNumbers[i] = byteCharacters.charCodeAt(i);
+            }
+            const byteArray = new Uint8Array(byteNumbers);
+            const blob = new Blob([byteArray], { type: "video/mp4" });
+            videoUrl = URL.createObjectURL(blob);
+        } else {
+            console.error("No video data available for download");
+            return;
+        }
+
+        // Create a temporary anchor element to trigger download
+        const link = document.createElement("a");
+        link.href = videoUrl;
+        link.download = filename;
+        link.style.display = "none";
+        document.body.appendChild(link);
+        link.click();
+        document.body.removeChild(link);
+
+        // Clean up object URL if we created one
+        if (item.b64_json && videoUrl.startsWith("blob:")) {
+            setTimeout(() => URL.revokeObjectURL(videoUrl), 100);
+        }
+
+        // Show feedback
+        const originalHTML = button.innerHTML;
+        button.innerHTML = 'Downloaded!';
+        setTimeout(() => {
+            button.innerHTML = originalHTML;
+        }, 2000);
+    } catch (error) {
+        console.error("Error downloading video:", error);
+        button.innerHTML = 'Error';
+        setTimeout(() => {
+            button.innerHTML = 'Download Video';
+        }, 2000);
+    }
+}
+
 // Initialize
 document.addEventListener("DOMContentLoaded", function() {
     const input = document.getElementById("input");
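
For consumers outside the browser, the b64_json branch of downloadVideo above has a direct equivalent. A sketch in Python, assuming response items carry the same url / b64_json fields the JS reads; the payload shape and filenames are assumptions:

    # Sketch of decoding an inline base64 video payload outside the browser;
    # the response shape is assumed to mirror the item.url / item.b64_json
    # fields read by video.js above.
    import base64
    import json

    with open("response.json") as f:          # hypothetical saved API response
        item = json.load(f)["data"][0]        # assumed: items under a "data" array

    if "b64_json" in item:
        with open("generated-video.mp4", "wb") as out:
            out.write(base64.b64decode(item["b64_json"]))
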