jbilcke-hf committed (verified)
Commit 9ddc16e · 1 Parent(s): 6ecbc3e

Update handler.py

Files changed (1)
  1. handler.py +8 -28
handler.py CHANGED
@@ -16,7 +16,6 @@ from diffusers import LTXPipeline, LTXImageToVideoPipeline
 from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
 from PIL import Image
 
-from teacache import TeaCacheConfig, enable_teacache, disable_teacache
 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
 
@@ -149,12 +148,16 @@ class EndpointHandler:
                 torch_dtype=torch.bfloat16
             ).to("cuda")
 
+            apply_teacache(self.image_to_video)
+
         else:
             # Initialize models with bfloat16 precision
             self.text_to_video = LTXPipeline.from_pretrained(
                 model_path,
                 torch_dtype=torch.bfloat16
             ).to("cuda")
+
+            apply_teacache(self.text_to_video)
 
         # Initialize LoRA tracking
         self._current_lora_model = None
@@ -195,32 +198,11 @@ class EndpointHandler:
             enable_mmaudio=True,
         )
 
-        # Store TeaCache config for each model
-        self.text_to_video_teacache = None
-        self.image_to_video_teacache = None
+        # Determine if TeaCache is already installed or not
+        self.text_to_video_teacache = False
+        self.image_to_video_teacache = False
+
 
-    def _configure_teacache(self, model, config: GenerationConfig):
-        """Configure TeaCache for a model based on generation config
-
-        Args:
-            model: The model to configure TeaCache for
-            config: Generation configuration
-        """
-        if config.enable_teacache:
-            # Create and enable TeaCache if it should be enabled
-            teacache_config = TeaCacheConfig(
-                enabled=True,
-                rel_l1_thresh=config.teacache_threshold,
-                num_inference_steps=config.num_inference_steps
-            )
-            enable_teacache(model.transformer.__class__, teacache_config)
-            logger.info(f"TeaCache enabled with threshold {config.teacache_threshold}")
-        else:
-            # Disable TeaCache if it was previously enabled
-            if hasattr(model.transformer.__class__, 'teacache_config'):
-                disable_teacache(model.transformer.__class__)
-                logger.info("TeaCache disabled")
-
     async def process_frames(
         self,
         frames: torch.Tensor,
@@ -451,7 +433,6 @@ class EndpointHandler:
 
         # Check if image-to-video generation is requested
        if support_image_prompt and input_image:
-            self._configure_teacache(self.image_to_video, config)
             processed_image = process_input_image(
                 input_image,
                 config.width,
@@ -463,7 +444,6 @@ class EndpointHandler:
             # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
             frames = self.image_to_video(**generation_kwargs).frames
         else:
-            self._configure_teacache(self.text_to_video, config)
             # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
             # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
             frames = self.text_to_video(**generation_kwargs).frames
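For readers following along: the commit replaces the per-request _configure_teacache method with a single apply_teacache call at pipeline initialization. The helper's definition and import are not part of this diff; below is a minimal sketch of what it might look like, assuming it wraps the same teacache module API the removed method used (TeaCacheConfig, enable_teacache). The default threshold and step count are illustrative, not values taken from this commit.

# Sketch only: apply_teacache is not defined anywhere in this diff.
# It assumes the teacache module API used by the removed
# _configure_teacache; the defaults below are illustrative.
from teacache import TeaCacheConfig, enable_teacache

def apply_teacache(pipeline, rel_l1_thresh: float = 0.05, num_inference_steps: int = 50):
    """Install TeaCache once on the pipeline's transformer class at init time."""
    teacache_config = TeaCacheConfig(
        enabled=True,
        rel_l1_thresh=rel_l1_thresh,
        num_inference_steps=num_inference_steps,
    )
    # Hook the transformer class, as the removed per-request path did
    enable_teacache(pipeline.transformer.__class__, teacache_config)

Installing once in __init__ sidesteps the repeated-installation problem noted in the Enhance-A-Video comments (the hook cannot be installed twice without uninstalling it first), which is presumably why the new text_to_video_teacache / image_to_video_teacache booleans now track one-time installation state instead of holding per-request configs.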