jbilcke-hf (HF staff) committed
Commit be2df75 (1 parent: 9840797)

Update handler.py

Files changed (1)
  1. handler.py +136 -187
handler.py CHANGED
@@ -1,75 +1,90 @@
- from typing import Dict, Any, Union, Optional, Tuple
- import torch
- from diffusers import LTXPipeline, LTXImageToVideoPipeline
- from PIL import Image
  import base64
  import io
- import tempfile
  import random
  import numpy as np
- from moviepy.editor import ImageSequenceClip
- import os
- import logging
- import asyncio
  from varnish import Varnish

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- ENABLE_CPU_OFFLOAD = True
- EXPERIMENTAL_STUFF = False
-
- random.seed(0)
- np.random.seed(0)
- generator = torch.manual_seed(0)
- # you can notice we don't use device=cuda, for more info see:
- # https://huggingface.co/docs/diffusers/v0.16.0/en/using-diffusers/reproducibility#gpu
-
- varnish = Varnish(
-     enable_mmaudio=False,
-     #mmaudio_config=mmaudio_config
- )

  class EndpointHandler:
-     # Default configuration
-     DEFAULT_FPS = 24
-     DEFAULT_DURATION = 4 # seconds
-     DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1 # 97 frames
-     DEFAULT_NUM_STEPS = 25
-     DEFAULT_WIDTH = 768
-     DEFAULT_HEIGHT = 512
-
-     # Constraints
-     MAX_WIDTH = 1280
-     MAX_HEIGHT = 720
-     MAX_FRAMES = 257
-

-     def __init__(self, path: str = ""):
-         """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
-
          Args:
-             path (str): Path to the model weights directory
          """
-         if EXPERIMENTAL_STUFF:
-             torch.backends.cuda.matmul.allow_tf32 = True

-         # Load both pipelines with bfloat16 precision as recommended in docs
          self.text_to_video = LTXPipeline.from_pretrained(
-             path,
              torch_dtype=torch.bfloat16
          ).to("cuda")

          self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
-             path,
              torch_dtype=torch.bfloat16
          ).to("cuda")

-         if ENABLE_CPU_OFFLOAD:
-             self.text_to_video.enable_model_cpu_offload()
-             self.image_to_video.enable_model_cpu_offload()
-
          self.varnish = Varnish(
              device="cuda" if torch.cuda.is_available() else "cpu",
              output_format="mp4",
@@ -78,172 +93,115 @@ class EndpointHandler:
              enable_mmaudio=False
          )

-     def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
-         """Validate and adjust resolution to meet constraints.
-
-         Args:
-             width (int): Requested width
-             height (int): Requested height
-
-         Returns:
-             Tuple[int, int]: Adjusted (width, height)
-         """
-         # Round to nearest multiple of 32
-         width = round(width / 32) * 32
-         height = round(height / 32) * 32
-
-         # Enforce maximum dimensions
-         width = min(width, self.MAX_WIDTH)
-         height = min(height, self.MAX_HEIGHT)
-
-         # Enforce minimum dimensions
-         width = max(width, 32)
-         height = max(height, 32)
-
-         return width, height
-
-     def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
-         """Validate and adjust frame count and FPS to meet constraints.

          Args:
-             num_frames (Optional[int]): Requested number of frames
-             fps (Optional[int]): Requested frames per second

          Returns:
-             Tuple[int, int]: Adjusted (num_frames, fps)
          """
-         # Use defaults if not provided
-         fps = fps or self.DEFAULT_FPS
-         num_frames = num_frames or self.DEFAULT_NUM_FRAMES
-
-         # Adjust frames to be in format 8k + 1
-         k = (num_frames - 1) // 8
-         num_frames = (k * 8) + 1
-
-         # Enforce maximum frame count
-         num_frames = min(num_frames, self.MAX_FRAMES)
-
-         return num_frames, fps
-
-     async def process_and_encode_video(
-         self,
-         frames: torch.Tensor,
-         fps: int,
-         upscale_factor: int = 0,
-         enable_interpolation: bool = False,
-         interpolation_exp: int = 1
-     ) -> tuple[str, dict]:
-         """Process video frames using Varnish and return base64 encoded result"""
-
          # Process video with Varnish
          result = await self.varnish(
              input_data=frames,
-             input_fps=fps,
-             output_fps=fps,
-             enable_upscale=upscale_factor > 1,
-             upscale_factor=upscale_factor,
-             enable_interpolation=enable_interpolation,
-             interpolation_exp=interpolation_exp
          )

-         # Get video as data URI
-         video_data_uri = await result.write(
              output_type="data-uri",
              output_format="mp4",
              output_codec="h264",
              output_quality=23
          )

          metadata = {
              "width": result.metadata.width,
              "height": result.metadata.height,
              "num_frames": result.metadata.frame_count,
              "fps": result.metadata.fps,
-             "duration": result.metadata.duration
          }

-         return video_data_uri, metadata

-     def _run_async(self, frames: torch.Tensor, fps: int, upscale_factor: int, enable_interpolation: bool, interpolation_exp: int) -> Dict[str, Any]:
-         """Run asynchronous video processing in a synchronous context"""
-         loop = asyncio.new_event_loop()
-         try:
-             return loop.run_until_complete(
-                 self.process_and_encode_video(
-                     frames=frames,
-                     fps=fps,
-                     upscale_factor=upscale_factor,
-                     enable_interpolation=enable_interpolation,
-                     interpolation_exp=interpolation_exp
-                 )
-             )
-         finally:
-             loop.close()
-
      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-         """Process the input data and generate video using LTX.

          Args:
-             data (Dict[str, Any]): Input data containing:
-                 - prompt (str): Text description for video generation
-                 - image (Optional[str]): Base64 encoded image for image-to-video generation
-                 - width (Optional[int]): Video width (default: 768)
-                 - height (Optional[int]): Video height (default: 512)
-                 - num_frames (Optional[int]): Number of frames (default: 97)
-                 - fps (Optional[int]): Frames per second (default: 24)
-                 - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
-                 - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
-
          Returns:
-             Dict[str, Any]: Dictionary containing:
-                 - video: video encoded in Base64 (h.264 MP4 video). This is a data-uri (prefixed with "data:").
-                 - content-type: MIME type of the video (right now always "video/mp4")
-                 - metadata: Dictionary with actual values used for generation
          """
-
-         prompt = data.get("inputs", None)
          if not prompt:
              raise ValueError("No prompt provided in the 'inputs' field")

-         # Get generation parameters
-         width = data.get("width", self.DEFAULT_WIDTH)
-         height = data.get("height", self.DEFAULT_HEIGHT)
-         width, height = self._validate_and_adjust_resolution(width, height)
-
-         num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
-         fps = data.get("fps", self.DEFAULT_FPS)
-         num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
-
-         # Get post-processing parameters
-         upscale_factor = data.get("upscale_factor", 0)
-         enable_interpolation = data.get("enable_interpolation", False)
-         interpolation_exp = data.get("interpolation_exp", 1)
-
-         guidance_scale = data.get("guidance_scale", 7.5)
-         num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)
-         seed = data.get("seed", -1)
-         seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)

          try:
              with torch.no_grad():
-                 random.seed(seed)
-                 np.random.seed(seed)
-                 generator.manual_seed(seed)

                  generation_kwargs = {
                      "prompt": prompt,
-                     "height": height,
-                     "width": width,
-                     "num_frames": num_frames,
-                     "guidance_scale": guidance_scale,
-                     "num_inference_steps": num_inference_steps,
                      "output_type": "pt",
                      "generator": generator
                  }

-                 # Generate frames using appropriate pipeline
                  image_data = data.get("image")
                  if image_data:
                      if image_data.startswith('data:'):
                          image_data = image_data.split(',', 1)[1]
                          image_bytes = base64.b64decode(image_data)
@@ -253,26 +211,17 @@ class EndpointHandler:
                  else:
                      frames = self.text_to_video(**generation_kwargs).frames

-                 # Process and encode video
-                 video_data_uri, metadata = self._run_async(
-                     frames=frames,
-                     fps=fps,
-                     upscale_factor=upscale_factor,
-                     enable_interpolation=enable_interpolation,
-                     interpolation_exp=interpolation_exp
-                 )
-
-                 # Add generation metadata
-                 metadata.update({
-                     "num_inference_steps": num_inference_steps,
-                     "seed": seed,
-                     "upscale_factor": upscale_factor,
-                     "interpolation_enabled": enable_interpolation,
-                     "interpolation_exp": interpolation_exp
-                 })

                  return {
-                     "video": video_data_uri,
                      "content-type": "video/mp4",
                      "metadata": metadata
                  }
 
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Any, Optional, Tuple
+ import asyncio
  import base64
  import io
+ import logging
  import random
+
  import numpy as np
+ import torch
+ from diffusers import LTXPipeline, LTXImageToVideoPipeline
+ from PIL import Image
  from varnish import Varnish

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # Constraints
+ MAX_WIDTH = 1280
+ MAX_HEIGHT = 720
+ MAX_FRAMES = 257
+
+ @dataclass
+ class GenerationConfig:
+     """Configuration for video generation"""
+     width: int = 768
+     height: int = 512
+     fps: int = 24
+     duration_sec: float = 4.0
+     num_inference_steps: int = 30
+     guidance_scale: float = 7.5
+     upscale_factor: float = 2.0
+     enable_interpolation: bool = False
+     seed: int = -1 # -1 means random seed
+
+     @property
+     def num_frames(self) -> int:
+         """Calculate number of frames based on fps and duration"""
+         return int(self.duration_sec * self.fps) + 1
+
+     def validate_and_adjust(self) -> 'GenerationConfig':
+         """Validate and adjust parameters to meet constraints"""
+         # Round dimensions to nearest multiple of 32
+         self.width = max(32, min(MAX_WIDTH, round(self.width / 32) * 32))
+         self.height = max(32, min(MAX_HEIGHT, round(self.height / 32) * 32))
+
+         # Adjust number of frames to be in format 8k + 1
+         k = (self.num_frames - 1) // 8
+         num_frames = min((k * 8) + 1, MAX_FRAMES)
+         self.duration_sec = (num_frames - 1) / self.fps
+
+         # Set random seed if not specified
+         if self.seed == -1:
+             self.seed = random.randint(0, 2**32 - 1)
+
+         return self

  class EndpointHandler:
+     """Handles video generation requests using LTX models and Varnish post-processing"""

+     def __init__(self, model_path: str = ""):
+         """Initialize the handler with LTX models and Varnish
+
          Args:
+             model_path: Path to LTX model weights
          """
+         # Enable TF32 for potential speedup on Ampere GPUs
+         #torch.backends.cuda.matmul.allow_tf32 = True

+         # Initialize models with bfloat16 precision
          self.text_to_video = LTXPipeline.from_pretrained(
+             model_path,
              torch_dtype=torch.bfloat16
          ).to("cuda")

          self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
+             model_path,
              torch_dtype=torch.bfloat16
          ).to("cuda")

+         # Enable CPU offload for memory efficiency
+         #self.text_to_video.enable_model_cpu_offload()
+         #self.image_to_video.enable_model_cpu_offload()

+         # Initialize Varnish for post-processing
          self.varnish = Varnish(
              device="cuda" if torch.cuda.is_available() else "cpu",
              output_format="mp4",
 
              enable_mmaudio=False
          )

+     async def process_frames(
+         self,
+         frames: torch.Tensor,
+         config: GenerationConfig
+     ) -> tuple[str, dict]:
+         """Post-process generated frames using Varnish

          Args:
+             frames: Generated video frames tensor
+             config: Generation configuration

          Returns:
+             Tuple of (video data URI, metadata dictionary)
          """
          # Process video with Varnish
          result = await self.varnish(
              input_data=frames,
+             input_fps=config.fps,
+             upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
+             enable_interpolation=config.enable_interpolation,
+             output_fps=config.fps
          )

+         # Convert to data URI
+         video_uri = await result.write(
              output_type="data-uri",
              output_format="mp4",
              output_codec="h264",
              output_quality=23
          )

+         # Collect metadata
          metadata = {
              "width": result.metadata.width,
              "height": result.metadata.height,
              "num_frames": result.metadata.frame_count,
              "fps": result.metadata.fps,
+             "duration": result.metadata.duration,
+             "num_inference_steps": config.num_inference_steps,
+             "seed": config.seed,
+             "upscale_factor": config.upscale_factor,
+             "interpolation_enabled": config.enable_interpolation
          }

+         return video_uri, metadata

      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """Process incoming requests for video generation

          Args:
+             data: Request data containing:
+                 - inputs (str): Text prompt or image
+                 - width (optional): Video width
+                 - height (optional): Video height
+                 - fps (optional): Frames per second
+                 - duration_sec (optional): Video duration
+                 - num_inference_steps (optional): Inference steps
+                 - guidance_scale (optional): Guidance scale
+                 - upscale_factor (optional): Upscaling factor
+                 - enable_interpolation (optional): Enable frame interpolation
+                 - seed (optional): Random seed
+
          Returns:
+             Dictionary containing:
+                 - video: Base64 encoded MP4 data URI
+                 - content-type: MIME type
+                 - metadata: Generation metadata
          """
+         # Extract prompt
+         prompt = data.get("inputs")
          if not prompt:
              raise ValueError("No prompt provided in the 'inputs' field")

+         # Create and validate configuration
+         config = GenerationConfig(
+             width=data.get("width", GenerationConfig.width),
+             height=data.get("height", GenerationConfig.height),
+             fps=data.get("fps", GenerationConfig.fps),
+             duration_sec=data.get("duration_sec", GenerationConfig.duration_sec),
+             num_inference_steps=data.get("num_inference_steps", GenerationConfig.num_inference_steps),
+             guidance_scale=data.get("guidance_scale", GenerationConfig.guidance_scale),
+             upscale_factor=data.get("upscale_factor", GenerationConfig.upscale_factor),
+             enable_interpolation=data.get("enable_interpolation", GenerationConfig.enable_interpolation),
+             seed=data.get("seed", GenerationConfig.seed)
+         ).validate_and_adjust()

          try:
              with torch.no_grad():
+                 # Set random seeds
+                 random.seed(config.seed)
+                 np.random.seed(config.seed)
+                 generator = torch.manual_seed(config.seed)

+                 # Prepare generation parameters
                  generation_kwargs = {
                      "prompt": prompt,
+                     "height": config.height,
+                     "width": config.width,
+                     "num_frames": config.num_frames,
+                     "guidance_scale": config.guidance_scale,
+                     "num_inference_steps": config.num_inference_steps,
                      "output_type": "pt",
                      "generator": generator
                  }

+                 # Check if image-to-video generation is requested
                  image_data = data.get("image")
                  if image_data:
+                     # Process base64 image
                      if image_data.startswith('data:'):
                          image_data = image_data.split(',', 1)[1]
                          image_bytes = base64.b64decode(image_data)

                  else:
                      frames = self.text_to_video(**generation_kwargs).frames

+                 # Post-process frames
+                 loop = asyncio.new_event_loop()
+                 try:
+                     video_uri, metadata = loop.run_until_complete(
+                         self.process_frames(frames, config)
+                     )
+                 finally:
+                     loop.close()

                  return {
+                     "video": video_uri,
                      "content-type": "video/mp4",
                      "metadata": metadata
                  }
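
For illustration only, a minimal local invocation sketch based on the docstrings in the new handler.py. The model path and payload values below are assumptions for the example, not part of the commit:

    # hypothetical local test of the updated handler
    handler = EndpointHandler(model_path="Lightricks/LTX-Video")  # assumed weights location
    payload = {
        "inputs": "a red panda walking through a bamboo forest",  # text prompt
        "width": 768,
        "height": 512,
        "fps": 24,
        "duration_sec": 4.0,
        "seed": 42,
    }
    result = handler(payload)
    print(result["content-type"])   # "video/mp4"
    print(result["metadata"])       # resolution, frame count, seed, etc.
    # result["video"] holds the generated MP4 as a base64 data URI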