Update handler.py
Browse files- handler.py +28 -0
handler.py
CHANGED
@@ -280,6 +280,32 @@ class EndpointHandler:
|
|
280 |
enable_mmaudio=True,
|
281 |
)
|
282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
async def process_frames(
|
284 |
self,
|
285 |
frames: torch.Tensor,
|
@@ -444,6 +470,7 @@ class EndpointHandler:
|
|
444 |
|
445 |
# Check if image-to-video generation is requested
|
446 |
if input_image:
|
|
|
447 |
processed_image = process_input_image(
|
448 |
input_image,
|
449 |
config.width,
|
@@ -453,6 +480,7 @@ class EndpointHandler:
|
|
453 |
generation_kwargs["image"] = processed_image
|
454 |
frames = self.image_to_video(**generation_kwargs).frames
|
455 |
else:
|
|
|
456 |
frames = self.text_to_video(**generation_kwargs).frames
|
457 |
|
458 |
try:
|
|
|
280 |
enable_mmaudio=True,
|
281 |
)
|
282 |
|
283 |
+
# Store TeaCache config for each model
|
284 |
+
self.text_to_video_teacache = None
|
285 |
+
self.image_to_video_teacache = None
|
286 |
+
|
287 |
+
def _configure_teacache(self, model, config: GenerationConfig):
|
288 |
+
"""Configure TeaCache for a model based on generation config
|
289 |
+
|
290 |
+
Args:
|
291 |
+
model: The model to configure TeaCache for
|
292 |
+
config: Generation configuration
|
293 |
+
"""
|
294 |
+
if config.enable_teacache:
|
295 |
+
# Create and enable TeaCache if it should be enabled
|
296 |
+
teacache_config = TeaCacheConfig(
|
297 |
+
enabled=True,
|
298 |
+
rel_l1_thresh=config.teacache_threshold,
|
299 |
+
num_inference_steps=config.num_inference_steps
|
300 |
+
)
|
301 |
+
enable_teacache(model.transformer.__class__, teacache_config)
|
302 |
+
logger.info(f"TeaCache enabled with threshold {config.teacache_threshold}")
|
303 |
+
else:
|
304 |
+
# Disable TeaCache if it was previously enabled
|
305 |
+
if hasattr(model.transformer.__class__, 'teacache_config'):
|
306 |
+
disable_teacache(model.transformer.__class__)
|
307 |
+
logger.info("TeaCache disabled")
|
308 |
+
|
309 |
async def process_frames(
|
310 |
self,
|
311 |
frames: torch.Tensor,
|
|
|
470 |
|
471 |
# Check if image-to-video generation is requested
|
472 |
if input_image:
|
473 |
+
self._configure_teacache(self.image_to_video, config)
|
474 |
processed_image = process_input_image(
|
475 |
input_image,
|
476 |
config.width,
|
|
|
480 |
generation_kwargs["image"] = processed_image
|
481 |
frames = self.image_to_video(**generation_kwargs).frames
|
482 |
else:
|
483 |
+
self._configure_teacache(self.text_to_video, config)
|
484 |
frames = self.text_to_video(**generation_kwargs).frames
|
485 |
|
486 |
try:
|