import base64
import io
from typing import Any, Dict

import torch
from diffusers import LTXImageToVideoPipeline, LTXPipeline
from PIL import Image

class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
        
        Args:
            path (str): Path to the model weights directory
        """
        # Load both pipelines with bfloat16 precision as recommended in docs.
        # Note: this keeps two copies of the (largely shared) weights in
        # memory; acceptable here, but a candidate for deduplication.
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16
        )

        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16
        )

        # Reduce peak GPU memory by offloading idle submodules to CPU.
        # With offloading enabled, diffusers manages device placement itself,
        # so the pipelines must NOT also be moved wholesale with .to("cuda")
        # (the two mechanisms conflict).
        self.text_to_video.enable_model_cpu_offload()
        self.image_to_video.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate video using LTX.
        
        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text description for video generation
                - image (Optional[str]): Base64 encoded image for image-to-video generation
                - num_frames (Optional[int]): Number of frames to generate (default: 24)
                - guidance_scale (Optional[float]): Classifier-free guidance scale (default: 7.5)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 50)
        
        Returns:
            Dict[str, Any]: Dictionary containing:
                - frames: List of base64 encoded frames
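
        Example payload (illustrative; prompt text is arbitrary):
            {
                "prompt": "a sailboat drifting across a calm sea at dusk",
                "num_frames": 24
            }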
        """
        # Extract parameters
        prompt = data.get("prompt")
        if not prompt:
            raise ValueError("'prompt' is required in the input data")

        # Get optional parameters with defaults
        num_frames = data.get("num_frames", 24)
        guidance_scale = data.get("guidance_scale", 7.5)
        num_inference_steps = data.get("num_inference_steps", 50)

        # Check if image is provided for image-to-video generation
        image_data = data.get("image")
        
        try:
            if image_data:
                # Decode the base64 image; tolerate an optional data-URL
                # prefix ("data:image/png;base64,...") that some clients send.
                if image_data.strip().startswith("data:") and "," in image_data:
                    image_data = image_data.split(",", 1)[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                
                # Generate video from image
                output = self.image_to_video(
                    prompt=prompt,
                    image=image,
                    num_frames=num_frames,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps
                )
            else:
                # Generate video from text only
                output = self.text_to_video(
                    prompt=prompt,
                    num_frames=num_frames,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps
                )

            # Convert frames to base64
            frames = []
            for frame in output.frames[0]:  # frames is a list of videos; [0] is the single generated clip
                buffer = io.BytesIO()
                frame.save(buffer, format="PNG")
                frame_base64 = base64.b64encode(buffer.getvalue()).decode()
                frames.append(frame_base64)

            return {"frames": frames}

        except Exception as e:
            # Chain the original exception so the full traceback is preserved.
            raise RuntimeError(f"Error generating video: {e}") from e
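

# ---------------------------------------------------------------------------
# Local smoke test (illustrative sketch, not part of the Inference Endpoints
# contract). Assumes the LTX weights live at the hypothetical local path
# "./ltx-video" and that a CUDA-capable GPU is available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler("./ltx-video")

    # Text-to-video request: only "prompt" is required; the other keys fall
    # back to the defaults documented in __call__.
    result = handler({
        "prompt": "a sailboat drifting across a calm sea at dusk",
        "num_frames": 24,
    })

    # Decode the base64 PNG frames back into PIL images, mirroring what an
    # HTTP client of this endpoint would do with the JSON response.
    frames = [
        Image.open(io.BytesIO(base64.b64decode(f)))
        for f in result["frames"]
    ]
    print(f"Generated {len(frames)} frames of size {frames[0].size}")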