jbilcke-hf HF staff commited on
Commit
a1227fd
·
verified ·
1 Parent(s): 16f1cf0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +28 -0
handler.py CHANGED
@@ -280,6 +280,32 @@ class EndpointHandler:
280
  enable_mmaudio=True,
281
  )
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  async def process_frames(
284
  self,
285
  frames: torch.Tensor,
@@ -444,6 +470,7 @@ class EndpointHandler:
444
 
445
  # Check if image-to-video generation is requested
446
  if input_image:
 
447
  processed_image = process_input_image(
448
  input_image,
449
  config.width,
@@ -453,6 +480,7 @@ class EndpointHandler:
453
  generation_kwargs["image"] = processed_image
454
  frames = self.image_to_video(**generation_kwargs).frames
455
  else:
 
456
  frames = self.text_to_video(**generation_kwargs).frames
457
 
458
  try:
 
280
  enable_mmaudio=True,
281
  )
282
 
283
+ # Store TeaCache config for each model
284
+ self.text_to_video_teacache = None
285
+ self.image_to_video_teacache = None
286
+
287
+ def _configure_teacache(self, model, config: GenerationConfig):
288
+ """Configure TeaCache for a model based on generation config
289
+
290
+ Args:
291
+ model: The model to configure TeaCache for
292
+ config: Generation configuration
293
+ """
294
+ if config.enable_teacache:
295
+ # Create and enable TeaCache if it should be enabled
296
+ teacache_config = TeaCacheConfig(
297
+ enabled=True,
298
+ rel_l1_thresh=config.teacache_threshold,
299
+ num_inference_steps=config.num_inference_steps
300
+ )
301
+ enable_teacache(model.transformer.__class__, teacache_config)
302
+ logger.info(f"TeaCache enabled with threshold {config.teacache_threshold}")
303
+ else:
304
+ # Disable TeaCache if it was previously enabled
305
+ if hasattr(model.transformer.__class__, 'teacache_config'):
306
+ disable_teacache(model.transformer.__class__)
307
+ logger.info("TeaCache disabled")
308
+
309
  async def process_frames(
310
  self,
311
  frames: torch.Tensor,
 
470
 
471
  # Check if image-to-video generation is requested
472
  if input_image:
473
+ self._configure_teacache(self.image_to_video, config)
474
  processed_image = process_input_image(
475
  input_image,
476
  config.width,
 
480
  generation_kwargs["image"] = processed_image
481
  frames = self.image_to_video(**generation_kwargs).frames
482
  else:
483
+ self._configure_teacache(self.text_to_video, config)
484
  frames = self.text_to_video(**generation_kwargs).frames
485
 
486
  try: