jbilcke-hf (HF staff) committed
Commit 9d84818
1 Parent(s): ea52235

Update handler.py

Files changed (1)
  1. handler.py +203 -0
handler.py CHANGED
@@ -104,6 +104,209 @@ class GenerationConfig:
 
        return self
 
+class EndpointHandler:
+    """Handles video generation requests using LTX models and Varnish post-processing"""
+
+    def __init__(self, model_path: str = ""):
+        """Initialize the handler with LTX models and Varnish
+
+        Args:
+            model_path: Path to LTX model weights
+        """
+        # Enable TF32 for potential speedup on Ampere GPUs
+        #torch.backends.cuda.matmul.allow_tf32 = True
+
+        # Initialize models with bfloat16 precision
+        self.text_to_video = LTXPipeline.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16
+        ).to("cuda")
+
+        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16
+        ).to("cuda")
+
+        # Enable CPU offload for memory efficiency
+        #self.text_to_video.enable_model_cpu_offload()
+        #self.image_to_video.enable_model_cpu_offload()
+
+        # Initialize Varnish for post-processing
+        self.varnish = Varnish(
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            output_format="mp4",
+            output_codec="h264",
+            output_quality=23,
+            enable_mmaudio=False,
+            #model_base_dir=os.path.abspath(os.path.join(os.getcwd(), "varnish"))
+            model_base_dir="/repository/varnish",
+        )
+
+    async def process_frames(
+        self,
+        frames: torch.Tensor,
+        config: GenerationConfig
+    ) -> tuple[str, dict]:
+        """Post-process generated frames using Varnish
+
+        Args:
+            frames: Generated video frames tensor
+            config: Generation configuration
+
+        Returns:
+            Tuple of (video data URI, metadata dictionary)
+        """
+        try:
+            logger.info(f"Original frames shape: {frames.shape}")
+
+            # Remove batch dimension if present
+            if len(frames.shape) == 5:
+                frames = frames.squeeze(0)  # Remove batch dimension
+
+            logger.info(f"Processed frames shape: {frames.shape}")
+
+            # Process video with Varnish
+            result = await self.varnish(
+                input_data=frames,
+                input_fps=config.fps,
+                output_fps=config.fps,
+                upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
+                enable_interpolation=config.enable_interpolation
+            )
+
+            # Convert to data URI
+            video_uri = await result.write(
+                output_type="data-uri",
+                output_format="mp4",
+                output_codec="h264",
+                output_quality=23
+            )
+
+            # Collect metadata
+            metadata = {
+                "width": result.metadata.width,
+                "height": result.metadata.height,
+                "num_frames": result.metadata.frame_count,
+                "fps": result.metadata.fps,
+                "duration": result.metadata.duration,
+                "num_inference_steps": config.num_inference_steps,
+                "seed": config.seed,
+                "upscale_factor": config.upscale_factor,
+                "interpolation_enabled": config.enable_interpolation
+            }
+
+            return video_uri, metadata
+
+        except Exception as e:
+            logger.error(f"Error in process_frames: {str(e)}")
+            raise RuntimeError(f"Failed to process frames: {str(e)}")
+
+from dataclasses import dataclass
+from pathlib import Path
+import pathlib
+from typing import Dict, Any, Optional, Tuple
+import asyncio
+import base64
+import io
+import pprint
+import logging
+import random
+import traceback
+import os
+import numpy as np
+import torch
+from diffusers import LTXPipeline, LTXImageToVideoPipeline
+from PIL import Image
+
+from varnish import Varnish
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Constraints
+MAX_WIDTH = 1280
+MAX_HEIGHT = 720
+MAX_FRAMES = 257
+
+# this is only a temporary solution (famous last words)
+def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
+    """
+    Recursively rename all '.wut' files to '.pth' in the given directory
+
+    Args:
+        directory (str): Path to the directory to process
+    """
+    # Convert the directory path to absolute path
+    directory = os.path.abspath(directory)
+
+    # Walk through directory and its subdirectories
+    for root, _, files in os.walk(directory):
+        for filename in files:
+            if filename.endswith('.wut'):
+                # Get full path of the file
+                old_path = os.path.join(root, filename)
+                # Create new filename by replacing the extension
+                new_filename = filename.replace('.wut', '.pth')
+                new_path = os.path.join(root, new_filename)
+
+                try:
+                    os.rename(old_path, new_path)
+                    print(f"Renamed: {old_path} -> {new_path}")
+                except OSError as e:
+                    print(f"Error renaming {old_path}: {e}")
+
+def print_directory_structure(startpath):
+    """Print the directory structure starting from the given path."""
+    for root, dirs, files in os.walk(startpath):
+        level = root.replace(startpath, '').count(os.sep)
+        indent = ' ' * 4 * level
+        logger.info(f"{indent}{os.path.basename(root)}/")
+        subindent = ' ' * 4 * (level + 1)
+        for f in files:
+            logger.info(f"{subindent}{f}")
+
+logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
+apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
+
+logger.info("💡 Printing directory structure of ""/repository"":")
+print_directory_structure("/repository")
+
+@dataclass
+class GenerationConfig:
+    """Configuration for video generation"""
+    width: int = 768
+    height: int = 512
+    fps: int = 24
+    duration_sec: float = 4.0
+    num_inference_steps: int = 30
+    guidance_scale: float = 7.5
+    upscale_factor: float = 2.0
+    enable_interpolation: bool = False
+    seed: int = -1  # -1 means random seed
+
+    @property
+    def num_frames(self) -> int:
+        """Calculate number of frames based on fps and duration"""
+        return int(self.duration_sec * self.fps) + 1
+
+    def validate_and_adjust(self) -> 'GenerationConfig':
+        """Validate and adjust parameters to meet constraints"""
+        # Round dimensions to nearest multiple of 32
+        self.width = max(32, min(MAX_WIDTH, round(self.width / 32) * 32))
+        self.height = max(32, min(MAX_HEIGHT, round(self.height / 32) * 32))
+
+        # Adjust number of frames to be in format 8k + 1
+        k = (self.num_frames - 1) // 8
+        num_frames = min((k * 8) + 1, MAX_FRAMES)
+        self.duration_sec = (num_frames - 1) / self.fps
+
+        # Set random seed if not specified
+        if self.seed == -1:
+            self.seed = random.randint(0, 2**32 - 1)
+
+        return self
+
 class EndpointHandler:
     """Handles video generation requests using LTX models and Varnish post-processing"""
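
Not part of the commit, but as a quick sanity check of the constraints introduced above, here is a minimal sketch (assuming the GenerationConfig dataclass and the MAX_WIDTH/MAX_HEIGHT/MAX_FRAMES constants defined in this diff) of how validate_and_adjust() snaps dimensions to multiples of 32 within 1280x720 and keeps the frame count in the 8k + 1 form:

# Illustrative sketch only; relies on the GenerationConfig added in this commit.
config = GenerationConfig(width=1000, height=600, fps=24, duration_sec=4.0)
config = config.validate_and_adjust()

print(config.width)         # 992  -> 1000 rounded to the nearest multiple of 32
print(config.height)        # 608  -> 600 rounded to the nearest multiple of 32
print(config.num_frames)    # 97   -> 4.0 s * 24 fps + 1, already of the form 8k + 1 (k = 12)
print(config.duration_sec)  # 4.0  -> recomputed as (97 - 1) / 24

Note that MAX_FRAMES = 257 is itself of the form 8k + 1 (k = 32), so clamping to it preserves the invariant.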