|
from typing import Dict, Any, Optional, Tuple

import asyncio
import base64
import io
import logging
import random

import numpy as np
import torch
from PIL import Image

from diffusers import LTXPipeline, LTXImageToVideoPipeline
from varnish import Varnish

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
ENABLE_CPU_OFFLOAD = True
EXPERIMENTAL_STUFF = False

# Seed everything deterministically at import time; per-request seeds are
# applied again inside __call__.
random.seed(0)
np.random.seed(0)
generator = torch.manual_seed(0)

|
class EndpointHandler:
    # Default generation settings
    DEFAULT_FPS = 24
    DEFAULT_DURATION = 4  # seconds
    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1  # 97 frames (LTX expects 8k + 1)
    DEFAULT_NUM_STEPS = 25
    DEFAULT_WIDTH = 768
    DEFAULT_HEIGHT = 512

    # Hard limits for generation
    MAX_WIDTH = 1280
    MAX_HEIGHT = 720
    MAX_FRAMES = 257  # 8 * 32 + 1
|
    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory
        """
        if EXPERIMENTAL_STUFF:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # enable_model_cpu_offload() manages device placement itself, so the
        # pipelines must not be moved to CUDA beforehand.
        if ENABLE_CPU_OFFLOAD:
            self.text_to_video.enable_model_cpu_offload()
            self.image_to_video.enable_model_cpu_offload()
        else:
            self.text_to_video.to("cuda")
            self.image_to_video.to("cuda")

        self.varnish = Varnish(
            device="cuda" if torch.cuda.is_available() else "cpu",
            output_format="mp4",
            output_codec="h264",
            output_quality=23,
            enable_mmaudio=False,
        )
|
    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
        """Validate and adjust resolution to meet model constraints.

        Each dimension is rounded to the nearest multiple of 32, then clamped
        to [32, MAX_WIDTH] x [32, MAX_HEIGHT].

        Args:
            width (int): Requested width
            height (int): Requested height

        Returns:
            Tuple[int, int]: Adjusted (width, height)
        """
        # Round to the nearest multiple of 32
        width = round(width / 32) * 32
        height = round(height / 32) * 32

        # Clamp to the supported range
        width = min(max(width, 32), self.MAX_WIDTH)
        height = min(max(height, 32), self.MAX_HEIGHT)

        return width, height
|
    def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
        """Validate and adjust frame count and FPS to meet model constraints.

        LTX requires the frame count to be of the form 8k + 1 (e.g. 97, 105, ...),
        so the requested count is rounded down to the nearest valid value and
        capped at MAX_FRAMES.

        Args:
            num_frames (Optional[int]): Requested number of frames
            fps (Optional[int]): Requested frames per second

        Returns:
            Tuple[int, int]: Adjusted (num_frames, fps)
        """
        fps = fps or self.DEFAULT_FPS
        num_frames = num_frames or self.DEFAULT_NUM_FRAMES

        # Round down to the nearest 8k + 1 (e.g. 100 -> 97)
        k = (num_frames - 1) // 8
        num_frames = (k * 8) + 1

        num_frames = min(num_frames, self.MAX_FRAMES)

        return num_frames, fps
|
    async def process_and_encode_video(
        self,
        frames: torch.Tensor,
        fps: int,
        upscale_factor: int = 0,
        enable_interpolation: bool = False,
        interpolation_exp: int = 1
    ) -> Tuple[str, dict]:
        """Process video frames using Varnish and return a base64-encoded result with metadata."""
        result = await self.varnish(
            input_data=frames,
            input_fps=fps,
            output_fps=fps,
            enable_upscale=upscale_factor > 1,
            upscale_factor=upscale_factor,
            enable_interpolation=enable_interpolation,
            interpolation_exp=interpolation_exp
        )

        # Encode the processed frames as an MP4 data-URI
        video_data_uri = await result.write(
            output_type="data-uri",
            output_format="mp4",
            output_codec="h264",
            output_quality=23
        )

        metadata = {
            "width": result.metadata.width,
            "height": result.metadata.height,
            "num_frames": result.metadata.frame_count,
            "fps": result.metadata.fps,
            "duration": result.metadata.duration
        }

        return video_data_uri, metadata
|
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video using LTX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs (str): Text prompt for video generation
                - image (Optional[str]): Base64-encoded image for image-to-video generation
                - width (Optional[int]): Video width (default: 768)
                - height (Optional[int]): Video height (default: 512)
                - num_frames (Optional[int]): Number of frames (default: 97)
                - fps (Optional[int]): Frames per second (default: 24)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: Base64-encoded H.264 MP4 video as a data-URI (prefixed with "data:")
                - content-type: MIME type of the video (currently always "video/mp4")
                - metadata: Dictionary with the actual values used for generation
        """
        prompt = data.get("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        width = data.get("width", self.DEFAULT_WIDTH)
        height = data.get("height", self.DEFAULT_HEIGHT)
        width, height = self._validate_and_adjust_resolution(width, height)

        num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
        fps = data.get("fps", self.DEFAULT_FPS)
        num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)

        # Post-processing options
        upscale_factor = data.get("upscale_factor", 0)
        enable_interpolation = data.get("enable_interpolation", False)
        interpolation_exp = data.get("interpolation_exp", 1)

        guidance_scale = data.get("guidance_scale", 7.5)
        num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)

        # A seed of -1 means "pick a random seed for this request"
        seed = data.get("seed", -1)
        seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)

        try:
            with torch.no_grad():
                random.seed(seed)
                np.random.seed(seed)
                generator.manual_seed(seed)

                generation_kwargs = {
                    "prompt": prompt,
                    "height": height,
                    "width": width,
                    "num_frames": num_frames,
                    "guidance_scale": guidance_scale,
                    "num_inference_steps": num_inference_steps,
                    "output_type": "pt",
                    "generator": generator
                }

                # If an input image is provided, run image-to-video; otherwise text-to-video
                image_data = data.get("image")
                if image_data:
                    # Strip the data-URI prefix, if present
                    if image_data.startswith('data:'):
                        image_data = image_data.split(',', 1)[1]
                    image_bytes = base64.b64decode(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    generation_kwargs["image"] = image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

            # process_and_encode_video is a coroutine, but __call__ is invoked
            # synchronously, so run it to completion on a fresh event loop.
            video_data_uri, metadata = asyncio.run(
                self.process_and_encode_video(
                    frames=frames,
                    fps=fps,
                    upscale_factor=upscale_factor,
                    enable_interpolation=enable_interpolation,
                    interpolation_exp=interpolation_exp
                )
            )

            metadata.update({
                "num_inference_steps": num_inference_steps,
                "seed": seed,
                "upscale_factor": upscale_factor,
                "interpolation_enabled": enable_interpolation,
                "interpolation_exp": interpolation_exp
            })

            return {
                "video": video_data_uri,
                "content-type": "video/mp4",
                "metadata": metadata
            }

        except Exception as e:
            logger.error(f"Error generating video: {e}")
            raise RuntimeError(f"Error generating video: {e}") from e
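
# Minimal local smoke test: a sketch of how the handler is called, not part of
# the endpoint contract. The model path below is a placeholder; running this
# requires a CUDA GPU and the LTX-Video weights available locally.
if __name__ == "__main__":
    handler = EndpointHandler(path="Lightricks/LTX-Video")  # hypothetical weights location
    response = handler({
        "inputs": "A sailboat gliding across a calm sea at sunset",
        "width": 768,
        "height": 512,
        "num_frames": 97,
        "fps": 24,
    })
    # The returned video is a long base64 data-URI, so print only the metadata.
    print(response["metadata"])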