Commit f6dd4f3 • Parent: 85f39ae
Update handler.py

handler.py CHANGED (+73, -84)
@@ -70,6 +70,14 @@ class EndpointHandler:
         self.text_to_video.enable_model_cpu_offload()
         self.image_to_video.enable_model_cpu_offload()
 
+        self.varnish = Varnish(
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            output_format="mp4",
+            output_codec="h264",
+            output_quality=23,
+            enable_mmaudio=False
+        )
+
     def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
         """Validate and adjust resolution to meet constraints.
 
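Note on this hunk: the Varnish post-processor is constructed once in __init__ and reused for every request. The corresponding import is not visible in this diff; the sketch below is a standalone reconstruction of the same configuration, assuming the package exposes a top-level Varnish class and that output_quality follows the x264 CRF convention (lower means higher quality; 23 is the encoder default).

# Sketch only: mirrors the constructor call in the hunk above.
# The import line is an assumption; it does not appear in this diff.
import torch
from varnish import Varnish

varnish = Varnish(
    device="cuda" if torch.cuda.is_available() else "cpu",  # prefer GPU
    output_format="mp4",   # container format
    output_codec="h264",   # video codec
    output_quality=23,     # presumably CRF-style: lower = higher quality
    enable_mmaudio=False,  # skip audio generation
)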
@@ -117,57 +125,44 @@ class EndpointHandler:
 
         return num_frames, fps
 
-    def …
-        …
+    async def process_and_encode_video(
+        self,
+        frames: torch.Tensor,
+        fps: int,
+        upscale_factor: int = 0,
+        enable_interpolation: bool = False,
+        interpolation_exp: int = 1
+    ) -> tuple[str, dict]:
+        """Process video frames using Varnish and return base64 encoded result"""
 
-            …
-            logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")
-
-            # Convert tensor to numpy array
-            video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
-            video_np = (video_np * 255).astype(np.uint8)
+        # Process video with Varnish
+        result = await self.varnish(
+            input_data=frames,
+            input_fps=fps,
+            output_fps=fps,
+            enable_upscale=upscale_factor > 1,
+            upscale_factor=upscale_factor,
+            enable_interpolation=enable_interpolation,
+            interpolation_exp=interpolation_exp
+        )
 
-            # Get …
-            …
+        # Get video as data URI
+        video_data_uri = await result.write(
+            output_type="data-uri",
+            output_format="mp4",
+            output_codec="h264",
+            output_quality=23
+        )
 
-            …
-            # values are: ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow, placebo
-            #
-            # there is a threads= field, by default None, which can be set to 2, 3, 4 etc..
-            clip.write_videofile(output_path, codec="libx264", audio=False)
-
-            # Read the video file
-            with open(output_path, "rb") as f:
-                video_content = f.read()
-
-            return video_content
-
-        finally:
-            # Cleanup
-            if os.path.exists(output_path):
-                os.remove(output_path)
-
-            # Clear memory
-            del video_np
-            torch.cuda.empty_cache()
+        metadata = {
+            "width": result.metadata.width,
+            "height": result.metadata.height,
+            "num_frames": result.metadata.frame_count,
+            "fps": result.metadata.fps,
+            "duration": result.metadata.duration
+        }
+
+        return video_data_uri, metadata
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """Process the input data and generate video using LTX.
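Worth flagging in this hunk: process_and_encode_video is a coroutine, while the surrounding context still shows a plain def __call__, and the call site later in this diff uses await. As committed, await inside a non-async function is a syntax error, so the handler presumably either declares __call__ as async def or bridges through an event loop. A minimal sketch of the bridging option, assuming nothing beyond the method signature above (encode_video_sync is an illustrative name, not part of the handler):

import asyncio

# Sketch: drive the async helper from synchronous code.
# asyncio.run() starts an event loop, awaits the coroutine,
# and tears the loop down when it completes.
def encode_video_sync(handler, frames, fps, **kwargs):
    return asyncio.run(
        handler.process_and_encode_video(frames=frames, fps=fps, **kwargs)
    )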
@@ -189,35 +184,32 @@ class EndpointHandler:
             - content-type: MIME type of the video (right now always "video/mp4")
             - metadata: Dictionary with actual values used for generation
         """
-
+
         prompt = data.get("inputs", None)
         if not prompt:
             raise ValueError("No prompt provided in the 'inputs' field")
 
-        # Get …
+        # Get generation parameters
         width = data.get("width", self.DEFAULT_WIDTH)
         height = data.get("height", self.DEFAULT_HEIGHT)
         width, height = self._validate_and_adjust_resolution(width, height)
-
-        # Get and validate frames and FPS
+
         num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
         fps = data.get("fps", self.DEFAULT_FPS)
         num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
-
-        # Get …
+
+        # Get post-processing parameters
+        upscale_factor = data.get("upscale_factor", 0)
+        enable_interpolation = data.get("enable_interpolation", False)
+        interpolation_exp = data.get("interpolation_exp", 1)
+
         guidance_scale = data.get("guidance_scale", 7.5)
         num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)
-
         seed = data.get("seed", -1)
         seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)
-
-        logger.info(f"Generating video with prompt: '{prompt}'")
-        logger.info(f"Video params: size={width}x{height}, num_frames={num_frames}, fps={fps}")
-        logger.info(f"Generation params: seed={seed}, guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")
 
         try:
             with torch.no_grad():
-
                 random.seed(seed)
                 np.random.seed(seed)
                 generator.manual_seed(seed)
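With this hunk, the request schema gains three optional post-processing fields alongside the existing generation parameters. Below is an illustrative payload as a Python dict; the key names come straight from the data.get(...) calls above, the values are invented, and the handler clamps width, height, num_frames, and fps to its own constraints.

# Every key below is read via data.get(...) in __call__; values are examples.
payload = {
    "inputs": "a red fox running through fresh snow",  # required prompt
    "width": 768,
    "height": 448,
    "num_frames": 121,
    "fps": 24,
    "guidance_scale": 7.5,
    "num_inference_steps": 30,
    "seed": -1,                    # -1 draws a random 32-bit seed
    "upscale_factor": 2,           # values > 1 enable Varnish upscaling
    "enable_interpolation": True,  # frame interpolation in post-processing
    "interpolation_exp": 1,
    # "image": "data:image/png;base64,...",  # optional: image-to-video mode
}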
@@ -233,43 +225,40 @@ class EndpointHandler:
                     "generator": generator
                 }
 
-                # …
+                # Generate frames using appropriate pipeline
                 image_data = data.get("image")
                 if image_data:
                     if image_data.startswith('data:'):
                         image_data = image_data.split(',', 1)[1]
                     image_bytes = base64.b64decode(image_data)
                     image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-                    logger.info("Using image-to-video generation mode")
                     generation_kwargs["image"] = image
-
+                    frames = self.image_to_video(**generation_kwargs).frames
                 else:
-
-                    output = self.text_to_video(**generation_kwargs).frames
+                    frames = self.text_to_video(**generation_kwargs).frames
 
-                # …
-                …
+                # Process and encode video
+                video_data_uri, metadata = await self.process_and_encode_video(
+                    frames=frames,
+                    fps=fps,
+                    upscale_factor=upscale_factor,
+                    enable_interpolation=enable_interpolation,
+                    interpolation_exp=interpolation_exp
+                )
 
-                # …
-                …
+                # Add generation metadata
+                metadata.update({
+                    "num_inference_steps": num_inference_steps,
+                    "seed": seed,
+                    "upscale_factor": upscale_factor,
+                    "interpolation_enabled": enable_interpolation,
+                    "interpolation_exp": interpolation_exp
+                })
 
-                # Add MP4 data URI prefix
-                video_data_uri = f"data:{content_type};base64,{video_base64}"
-
                 return {
                     "video": video_data_uri,
-                    "content-type": content_type,
-                    "metadata": {
-                        "width": width,
-                        "height": height,
-                        "num_frames": num_frames,
-                        "fps": fps,
-                        "duration": num_frames / fps,
-                        "num_inference_steps": num_inference_steps,
-                        "seed": seed
-                    }
+                    "content-type": "video/mp4",
+                    "metadata": metadata
                 }
 
         except Exception as e:
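Since the endpoint now returns the video as a data URI rather than raw MP4 bytes, a client has to strip the data: prefix before decoding. A small sketch, assuming only the response shape visible in the return statement above (save_video is an illustrative helper, not part of the handler):

import base64

def save_video(result: dict, path: str = "output.mp4") -> None:
    """Write the MP4 carried in the handler response to disk."""
    data_uri = result["video"]               # "data:video/mp4;base64,..."
    b64_payload = data_uri.split(",", 1)[1]  # drop the "data:...;base64," header
    with open(path, "wb") as f:
        f.write(base64.b64decode(b64_payload))

# result["metadata"] reports the values actually used, including width,
# height, num_frames, fps, duration, seed, and the post-processing flags.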