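"""Custom handler for Hugging Face Inference Endpoints serving LTX-Video.

Exposes both text-to-video (LTXPipeline) and image-to-video
(LTXImageToVideoPipeline) generation behind a single EndpointHandler.
"""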
import base64
import io
from typing import Any, Dict

import torch
from diffusers import LTXImageToVideoPipeline, LTXPipeline
from PIL import Image


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the LTX-Video handler with both text-to-video and
        image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory.
        """
        # Load both pipelines with bfloat16 precision, as recommended in the
        # LTX-Video documentation.
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        # enable_model_cpu_offload() manages device placement itself, moving
        # submodules onto the GPU only when they are needed; moving the
        # pipelines to CUDA manually beforehand would conflict with offloading,
        # so no explicit .to("cuda") call is made here.
        self.text_to_video.enable_model_cpu_offload()
        self.image_to_video.enable_model_cpu_offload()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video with LTX-Video.

        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text description for video generation.
                - image (Optional[str]): Base64-encoded image for
                  image-to-video generation.
                - num_frames (Optional[int]): Number of frames to generate
                  (default: 24).
                - guidance_scale (Optional[float]): Guidance scale
                  (default: 7.5).
                - num_inference_steps (Optional[int]): Number of inference
                  steps (default: 50).

        Returns:
            Dict[str, Any]: Dictionary containing:
                - frames: List of base64-encoded PNG frames.
        """
        # Extract the required prompt
        prompt = data.get("prompt")
        if not prompt:
            raise ValueError("'prompt' is required in the input data")

        # Get optional parameters with defaults
        num_frames = data.get("num_frames", 24)
        guidance_scale = data.get("guidance_scale", 7.5)
        num_inference_steps = data.get("num_inference_steps", 50)

        # An image in the payload switches to image-to-video generation
        image_data = data.get("image")
        try:
            if image_data:
                # Decode the base64-encoded conditioning image
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                # Generate video conditioned on the image
                output = self.image_to_video(
                    prompt=prompt,
                    image=image,
                    num_frames=num_frames,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                )
            else:
                # Generate video from text only
                output = self.text_to_video(
                    prompt=prompt,
                    num_frames=num_frames,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                )

            # output.frames is batched; frames[0] holds the PIL images for
            # this single request. Encode each frame as a base64 PNG.
            frames = []
            for frame in output.frames[0]:
                buffer = io.BytesIO()
                frame.save(buffer, format="PNG")
                frames.append(base64.b64encode(buffer.getvalue()).decode())

            return {"frames": frames}
        except Exception as e:
            raise RuntimeError(f"Error generating video: {e}") from e
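
# --- Illustrative local smoke test (not part of the endpoint contract) ---
# A minimal sketch of how this handler could be exercised directly, assuming
# LTX-Video weights are available at "./ltx-video" and a CUDA GPU is present;
# the model path and "test.png" filename are placeholders, not repo contents.
if __name__ == "__main__":
    handler = EndpointHandler(path="./ltx-video")

    # Text-to-video request
    result = handler({
        "prompt": "A sailboat drifting across a calm bay at sunset",
        "num_frames": 24,
    })
    print(f"Generated {len(result['frames'])} frames")

    # Image-to-video request: the conditioning image travels as base64
    with open("test.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()
    result = handler({"prompt": "The scene comes to life", "image": image_b64})

    # Decode the first returned frame back into a PIL image and save it
    first_frame = Image.open(io.BytesIO(base64.b64decode(result["frames"][0])))
    first_frame.save("frame_0.png")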