from typing import Dict, Any, Union, Optional, Tuple

import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image
import base64
import io
import tempfile
import random
import numpy as np
from moviepy.editor import ImageSequenceClip
import os
import logging
import asyncio

from varnish import Varnish

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ENABLE_CPU_OFFLOAD = True
EXPERIMENTAL_STUFF = False

random.seed(0)
np.random.seed(0)
generator = torch.manual_seed(0)
# Note that we don't pass device="cuda" to the generator; for more info see:
# https://huggingface.co/docs/diffusers/v0.16.0/en/using-diffusers/reproducibility#gpu

varnish = Varnish(
    enable_mmaudio=False,
    # mmaudio_config=mmaudio_config
)


class EndpointHandler:
    # Default configuration
    DEFAULT_FPS = 24
    DEFAULT_DURATION = 4  # seconds
    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1  # 97 frames
    DEFAULT_NUM_STEPS = 25
    DEFAULT_WIDTH = 768
    DEFAULT_HEIGHT = 512

    # Constraints
    MAX_WIDTH = 1280
    MAX_HEIGHT = 720
    MAX_FRAMES = 257

    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory
        """
        if EXPERIMENTAL_STUFF:
            torch.backends.cuda.matmul.allow_tf32 = True

        # Load both pipelines with bfloat16 precision, as recommended in the docs
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16
        )

        if ENABLE_CPU_OFFLOAD:
            # Model CPU offload manages device placement itself; calling
            # .to("cuda") first would defeat the memory savings.
            self.text_to_video.enable_model_cpu_offload()
            self.image_to_video.enable_model_cpu_offload()
        else:
            self.text_to_video.to("cuda")
            self.image_to_video.to("cuda")

        self.varnish = Varnish(
            device="cuda" if torch.cuda.is_available() else "cpu",
            output_format="mp4",
            output_codec="h264",
            output_quality=23,
            enable_mmaudio=False
        )

    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
        """Validate and adjust resolution to meet constraints.

        Args:
            width (int): Requested width
            height (int): Requested height

        Returns:
            Tuple[int, int]: Adjusted (width, height)
        """
        # Round to the nearest multiple of 32
        width = round(width / 32) * 32
        height = round(height / 32) * 32

        # Enforce maximum dimensions
        width = min(width, self.MAX_WIDTH)
        height = min(height, self.MAX_HEIGHT)

        # Enforce minimum dimensions
        width = max(width, 32)
        height = max(height, 32)

        return width, height

    def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
        """Validate and adjust frame count and FPS to meet constraints.
        Args:
            num_frames (Optional[int]): Requested number of frames
            fps (Optional[int]): Requested frames per second

        Returns:
            Tuple[int, int]: Adjusted (num_frames, fps)
        """
        # Use defaults if not provided
        fps = fps or self.DEFAULT_FPS
        num_frames = num_frames or self.DEFAULT_NUM_FRAMES

        # Adjust the frame count to the form 8k + 1
        k = (num_frames - 1) // 8
        num_frames = (k * 8) + 1

        # Enforce maximum frame count
        num_frames = min(num_frames, self.MAX_FRAMES)

        return num_frames, fps

    async def process_and_encode_video(
        self,
        frames: torch.Tensor,
        fps: int,
        upscale_factor: int = 0,
        enable_interpolation: bool = False,
        interpolation_exp: int = 1
    ) -> tuple[str, dict]:
        """Process video frames with Varnish and return the result as a base64 data URI."""
        # Post-process the video with Varnish
        result = await self.varnish(
            input_data=frames,
            input_fps=fps,
            output_fps=fps,
            enable_upscale=upscale_factor > 1,
            upscale_factor=upscale_factor,
            enable_interpolation=enable_interpolation,
            interpolation_exp=interpolation_exp
        )

        # Encode the video as a data URI
        video_data_uri = await result.write(
            output_type="data-uri",
            output_format="mp4",
            output_codec="h264",
            output_quality=23
        )

        metadata = {
            "width": result.metadata.width,
            "height": result.metadata.height,
            "num_frames": result.metadata.frame_count,
            "fps": result.metadata.fps,
            "duration": result.metadata.duration
        }

        return video_data_uri, metadata

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video using LTX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs (str): Text prompt describing the video to generate
                - image (Optional[str]): Base64-encoded image (raw or data URI) for image-to-video generation
                - width (Optional[int]): Video width (default: 768)
                - height (Optional[int]): Video height (default: 512)
                - num_frames (Optional[int]): Number of frames (default: 97)
                - fps (Optional[int]): Frames per second (default: 24)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - seed (Optional[int]): Random seed; -1 (default) picks a random seed
                - upscale_factor (Optional[int]): Varnish upscale factor (default: 0, disabled)
                - enable_interpolation (Optional[bool]): Enable frame interpolation (default: False)
                - interpolation_exp (Optional[int]): Interpolation exponent (default: 1)

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: the video encoded as a Base64 data URI (H.264 MP4, prefixed with "data:")
                - content-type: MIME type of the video (currently always "video/mp4")
                - metadata: Dictionary with the actual values used for generation
        """
        prompt = data.get("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Get generation parameters
        width = data.get("width", self.DEFAULT_WIDTH)
        height = data.get("height", self.DEFAULT_HEIGHT)
        width, height = self._validate_and_adjust_resolution(width, height)

        num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
        fps = data.get("fps", self.DEFAULT_FPS)
        num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)

        # Get post-processing parameters
        upscale_factor = data.get("upscale_factor", 0)
        enable_interpolation = data.get("enable_interpolation", False)
        interpolation_exp = data.get("interpolation_exp", 1)

        guidance_scale = data.get("guidance_scale", 7.5)
        num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)

        seed = data.get("seed", -1)
        seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)

        try:
            with torch.no_grad():
                # Re-seed all RNGs so results are reproducible for a given seed
                random.seed(seed)
                np.random.seed(seed)
                generator.manual_seed(seed)

                generation_kwargs = {
                    "prompt": prompt,
                    "height": height,
                    "width": width,
                    "num_frames": num_frames,
                    "guidance_scale": guidance_scale,
                    "num_inference_steps": num_inference_steps,
                    "output_type": "pt",
                    "generator": generator
                }

                # Generate frames using the appropriate pipeline
                image_data = data.get("image")
                if image_data:
                    # Accept both raw base64 strings and data URIs
                    if image_data.startswith("data:"):
                        image_data = image_data.split(",", 1)[1]
                    image_bytes = base64.b64decode(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    generation_kwargs["image"] = image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

            # __call__ must stay synchronous for the endpoint contract, so run the
            # async Varnish post-processing on its own event loop.
            video_data_uri, metadata = asyncio.run(
                self.process_and_encode_video(
                    frames=frames,
                    fps=fps,
                    upscale_factor=upscale_factor,
                    enable_interpolation=enable_interpolation,
                    interpolation_exp=interpolation_exp
                )
            )

            # Add generation metadata
            metadata.update({
                "num_inference_steps": num_inference_steps,
                "seed": seed,
                "upscale_factor": upscale_factor,
                "interpolation_enabled": enable_interpolation,
                "interpolation_exp": interpolation_exp
            })

            return {
                "video": video_data_uri,
                "content-type": "video/mp4",
                "metadata": metadata
            }

        except Exception as e:
            logger.error(f"Error generating video: {e}")
            raise RuntimeError(f"Error generating video: {e}") from e
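

# -----------------------------------------------------------------------------
# Illustrative local test: a minimal sketch of how the handler can be exercised
# outside the endpoint runtime. The model path "Lightricks/LTX-Video" and the
# sample prompt are assumptions for demonstration only, not values required by
# this file; a CUDA GPU with enough memory and the Varnish package are needed.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler("Lightricks/LTX-Video")  # hypothetical model path
    output = handler({
        "inputs": "A sailboat gliding across a calm sea at sunset",
        "width": 768,
        "height": 512,
        "num_frames": 97,
        "fps": 24,
        "num_inference_steps": 25,
    })
    logger.info(
        "Generated %s (%d chars as data URI), metadata: %s",
        output["content-type"],
        len(output["video"]),
        output["metadata"],
    )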