from typing import Dict, Any, Union, Optional, Tuple
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image
import base64
import io
import tempfile
import random
import numpy as np
from moviepy.editor import ImageSequenceClip
import os
import logging
import asyncio
from varnish import Varnish

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ENABLE_CPU_OFFLOAD = True
EXPERIMENTAL_STUFF = False

random.seed(0)
np.random.seed(0)
generator = torch.manual_seed(0)
# Note that we don't pass device="cuda" to the generator; for more info see:
# https://huggingface.co/docs/diffusers/v0.16.0/en/using-diffusers/reproducibility#gpu

varnish = Varnish(
    enable_mmaudio=False,
    #mmaudio_config=mmaudio_config
)

class EndpointHandler:
    # Default configuration
    DEFAULT_FPS = 24
    DEFAULT_DURATION = 4  # seconds
    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1  # 97 frames
    DEFAULT_NUM_STEPS = 25
    DEFAULT_WIDTH = 768
    DEFAULT_HEIGHT = 512
    
    # Constraints
    MAX_WIDTH = 1280
    MAX_HEIGHT = 720
    MAX_FRAMES = 257
    
    
    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
        
        Args:
            path (str): Path to the model weights directory
        """
        if EXPERIMENTAL_STUFF:
            torch.backends.cuda.matmul.allow_tf32 = True
        
        # Load both pipelines with bfloat16 precision as recommended in docs
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16
        ).to("cuda")

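        # enable_model_cpu_offload() keeps sub-models on the CPU and moves each one
        # to the GPU only while it is needed, lowering peak VRAM usage at the cost
        # of some extra latency per generation.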
        if ENABLE_CPU_OFFLOAD:
            self.text_to_video.enable_model_cpu_offload()
            self.image_to_video.enable_model_cpu_offload()

        self.varnish = Varnish(
            device="cuda" if torch.cuda.is_available() else "cpu",
            output_format="mp4",
            output_codec="h264",
            output_quality=23,
            enable_mmaudio=False
        )

    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
        """Validate and adjust resolution to meet constraints.
        
        Args:
            width (int): Requested width
            height (int): Requested height
            
        Returns:
            Tuple[int, int]: Adjusted (width, height)
        """
        # Round to nearest multiple of 32
        width = round(width / 32) * 32
        height = round(height / 32) * 32
        
        # Enforce maximum dimensions
        width = min(width, self.MAX_WIDTH)
        height = min(height, self.MAX_HEIGHT)
        
        # Enforce minimum dimensions
        width = max(width, 32)
        height = max(height, 32)
        
        return width, height

    def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
        """Validate and adjust frame count and FPS to meet constraints.
        
        Args:
            num_frames (Optional[int]): Requested number of frames
            fps (Optional[int]): Requested frames per second
            
        Returns:
            Tuple[int, int]: Adjusted (num_frames, fps)
        """
        # Use defaults if not provided
        fps = fps or self.DEFAULT_FPS
        num_frames = num_frames or self.DEFAULT_NUM_FRAMES
        
        # Adjust frames to be in format 8k + 1
        k = (num_frames - 1) // 8
        num_frames = (k * 8) + 1
        
        # Enforce maximum frame count
        num_frames = min(num_frames, self.MAX_FRAMES)
        
        return num_frames, fps

    async def process_and_encode_video(
        self,
        frames: torch.Tensor,
        fps: int,
        upscale_factor: int = 0,
        enable_interpolation: bool = False,
        interpolation_exp: int = 1
    ) -> tuple[str, dict]:
        """Process video frames using Varnish and return base64 encoded result"""
        
        # Process video with Varnish
        result = await self.varnish(
            input_data=frames,
            input_fps=fps,
            output_fps=fps,
            enable_upscale=upscale_factor > 1,
            upscale_factor=upscale_factor,
            enable_interpolation=enable_interpolation,
            interpolation_exp=interpolation_exp
        )
        
        # Get video as data URI
        video_data_uri = await result.write(
            output_type="data-uri",
            output_format="mp4",
            output_codec="h264",
            output_quality=23
        )
        
        metadata = {
            "width": result.metadata.width,
            "height": result.metadata.height,
            "num_frames": result.metadata.frame_count,
            "fps": result.metadata.fps,
            "duration": result.metadata.duration
        }
        
        return video_data_uri, metadata

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate video using LTX.
        
        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs (str): Text prompt describing the video to generate
                - image (Optional[str]): Base64-encoded image (or data URI) for image-to-video generation
                - width (Optional[int]): Video width (default: 768)
                - height (Optional[int]): Video height (default: 512)
                - num_frames (Optional[int]): Number of frames (default: 97)
                - fps (Optional[int]): Frames per second (default: 24)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - seed (Optional[int]): Random seed; -1 (default) picks a random seed
                - upscale_factor (Optional[int]): Upscale factor for post-processing (default: 0, disabled)
                - enable_interpolation (Optional[bool]): Enable frame interpolation (default: False)
                - interpolation_exp (Optional[int]): Interpolation exponent (default: 1)
        
        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: h.264 MP4 video encoded as a Base64 data URI (prefixed with "data:")
                - content-type: MIME type of the video (currently always "video/mp4")
                - metadata: Dictionary with actual values used for generation
        """

        prompt = data.get("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Get generation parameters
        width = data.get("width", self.DEFAULT_WIDTH)
        height = data.get("height", self.DEFAULT_HEIGHT)
        width, height = self._validate_and_adjust_resolution(width, height)
        
        num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
        fps = data.get("fps", self.DEFAULT_FPS)
        num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
        
        # Get post-processing parameters
        upscale_factor = data.get("upscale_factor", 0)
        enable_interpolation = data.get("enable_interpolation", False)
        interpolation_exp = data.get("interpolation_exp", 1)
        
        guidance_scale = data.get("guidance_scale", 7.5)
        num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)
        seed = data.get("seed", -1)
        seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)

        try:
            with torch.no_grad():
                random.seed(seed)
                np.random.seed(seed)
                generator.manual_seed(seed)
                
                generation_kwargs = {
                    "prompt": prompt,
                    "height": height,
                    "width": width,
                    "num_frames": num_frames,
                    "guidance_scale": guidance_scale,
                    "num_inference_steps": num_inference_steps,
                    "output_type": "pt",
                    "generator": generator
                }

                # Generate frames using appropriate pipeline
                image_data = data.get("image")
                if image_data:
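                    # Accept either a raw base64 string or a full data URI ("data:image/...;base64,...")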
                    if image_data.startswith('data:'):
                        image_data = image_data.split(',', 1)[1]
                    image_bytes = base64.b64decode(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    generation_kwargs["image"] = image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

                # Process and encode video. __call__ is synchronous, so run the
                # coroutine to completion here instead of awaiting it.
                video_data_uri, metadata = asyncio.run(self.process_and_encode_video(
                    frames=frames,
                    fps=fps,
                    upscale_factor=upscale_factor,
                    enable_interpolation=enable_interpolation,
                    interpolation_exp=interpolation_exp
                ))
                
                # Add generation metadata
                metadata.update({
                    "num_inference_steps": num_inference_steps,
                    "seed": seed,
                    "upscale_factor": upscale_factor,
                    "interpolation_enabled": enable_interpolation,
                    "interpolation_exp": interpolation_exp
                })
                
                return {
                    "video": video_data_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            logger.error(f"Error generating video: {e}")
            raise RuntimeError(f"Error generating video: {e}") from e
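

# Minimal local-invocation sketch for quick manual testing. This block is not part
# of the Inference Endpoints contract; the model path and payload values below are
# illustrative assumptions only.
if __name__ == "__main__":
    handler = EndpointHandler(path="Lightricks/LTX-Video")  # assumed model id/path
    result = handler({
        "inputs": "A wave crashing against a rocky shore at sunset, cinematic",
        "width": 768,
        "height": 512,
        "num_frames": 97,
        "fps": 24,
        "seed": 42,
    })
    print(result["content-type"], result["metadata"])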