# LTX-Video-0.9.1-HFIE / handler.py: custom EndpointHandler for Hugging Face Inference Endpoints
import base64
import io
from typing import Any, Dict

import torch
from diffusers import LTXImageToVideoPipeline, LTXPipeline
from PIL import Image

class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the LTX-Video handler with both text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory
        """
        # Load both pipelines with bfloat16 precision, as recommended by the LTX-Video docs
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # Enable memory optimizations. Model CPU offload manages device placement
        # on its own, so the pipelines are deliberately not moved to CUDA with
        # .to("cuda") first; doing both would negate most of the memory savings.
        self.text_to_video.enable_model_cpu_offload()
        self.image_to_video.enable_model_cpu_offload()
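
        # Note (assumption, untested here): the two pipelines wrap the same
        # underlying weights, so loading them separately roughly doubles memory.
        # diffusers' DiffusionPipeline.from_pipe can reuse the already-loaded
        # components instead, e.g.:
        #   self.image_to_video = LTXImageToVideoPipeline.from_pipe(self.text_to_video)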

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video with LTX-Video.

        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text description for video generation
                - image (Optional[str]): Base64-encoded image for image-to-video generation
                - num_frames (Optional[int]): Number of frames to generate (default: 24)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 50)

        Returns:
            Dict[str, Any]: Dictionary containing:
                - frames: List of base64-encoded PNG frames
        """
        # Extract the required prompt
        prompt = data.get("prompt")
        if not prompt:
            raise ValueError("'prompt' is required in the input data")

        # Optional parameters with defaults
        num_frames = data.get("num_frames", 24)
        guidance_scale = data.get("guidance_scale", 7.5)
        num_inference_steps = data.get("num_inference_steps", 50)
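
        # Note: the LTX-Video model card recommends frame counts of the form
        # 8k + 1 (e.g. 25, 121, 161) and resolutions divisible by 32; the
        # defaults above are kept from the original handler.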

        # If an image is provided, use image-to-video generation
        image_data = data.get("image")

        try:
            if image_data:
                # Decode the base64-encoded conditioning image
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

                # Generate video conditioned on the image
                output = self.image_to_video(
                    prompt=prompt,
                    image=image,
                    num_frames=num_frames,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                )
            else:
                # Generate video from text only
                output = self.text_to_video(
                    prompt=prompt,
                    num_frames=num_frames,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                )

            # output.frames is a batch of videos; take the first (and only)
            # video and encode each PIL frame as a base64 PNG string
            frames = []
            for frame in output.frames[0]:
                buffer = io.BytesIO()
                frame.save(buffer, format="PNG")
                frames.append(base64.b64encode(buffer.getvalue()).decode())

            return {"frames": frames}
        except Exception as e:
            # Chain the original exception so the traceback is preserved
            raise RuntimeError(f"Error generating video: {e}") from e
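

# Minimal local smoke test: a sketch assuming the LTX-Video weights are
# available at MODEL_PATH (a hypothetical path) and a CUDA GPU is present.
# The Inference Endpoints runtime only imports EndpointHandler and never
# runs this block.
if __name__ == "__main__":
    MODEL_PATH = "./"  # hypothetical: point this at the model weights directory
    handler = EndpointHandler(MODEL_PATH)
    result = handler({
        "prompt": "a seaside town at sunset, cinematic drone shot",
        "num_frames": 24,
    })
    # Round-trip the first frame back to PNG bytes to verify the encoding
    with open("frame_0.png", "wb") as f:
        f.write(base64.b64decode(result["frames"][0]))
    print("generated", len(result["frames"]), "frames")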