Update handler.py
handler.py
CHANGED: +8 −28
@@ -16,7 +16,6 @@ from diffusers import LTXPipeline, LTXImageToVideoPipeline
 from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
 from PIL import Image
 
-from teacache import TeaCacheConfig, enable_teacache, disable_teacache
 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
 
@@ -149,12 +148,16 @@ class EndpointHandler:
                 torch_dtype=torch.bfloat16
             ).to("cuda")
 
+            apply_teacache(self.image_to_video)
+
         else:
             # Initialize models with bfloat16 precision
             self.text_to_video = LTXPipeline.from_pretrained(
                 model_path,
                 torch_dtype=torch.bfloat16
             ).to("cuda")
+
+            apply_teacache(self.text_to_video)
 
         # Initialize LoRA tracking
         self._current_lora_model = None
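Note: the apply_teacache helper called above is not defined anywhere in this diff. A minimal sketch of what it could look like, assuming it wraps the same TeaCacheConfig / enable_teacache API used by the _configure_teacache method removed further down (the threshold and step-count defaults here are placeholders, not values from the commit):

from teacache import TeaCacheConfig, enable_teacache

def apply_teacache(pipeline, threshold=0.15, num_inference_steps=50):
    # Hypothetical sketch: install the TeaCache hook once on the pipeline's
    # transformer class, mirroring the removed _configure_teacache logic.
    # Defaults are placeholder assumptions, not taken from the commit.
    teacache_config = TeaCacheConfig(
        enabled=True,
        rel_l1_thresh=threshold,
        num_inference_steps=num_inference_steps,
    )
    enable_teacache(pipeline.transformer.__class__, teacache_config)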
@@ -195,32 +198,11 @@ class EndpointHandler:
             enable_mmaudio=True,
         )
 
-        #
-        self.text_to_video_teacache =
-        self.image_to_video_teacache =
+        # Determine if TeaCache is already installed or not
+        self.text_to_video_teacache = False
+        self.image_to_video_teacache = False
+
 
-    def _configure_teacache(self, model, config: GenerationConfig):
-        """Configure TeaCache for a model based on generation config
-
-        Args:
-            model: The model to configure TeaCache for
-            config: Generation configuration
-        """
-        if config.enable_teacache:
-            # Create and enable TeaCache if it should be enabled
-            teacache_config = TeaCacheConfig(
-                enabled=True,
-                rel_l1_thresh=config.teacache_threshold,
-                num_inference_steps=config.num_inference_steps
-            )
-            enable_teacache(model.transformer.__class__, teacache_config)
-            logger.info(f"TeaCache enabled with threshold {config.teacache_threshold}")
-        else:
-            # Disable TeaCache if it was previously enabled
-            if hasattr(model.transformer.__class__, 'teacache_config'):
-                disable_teacache(model.transformer.__class__)
-                logger.info("TeaCache disabled")
-
     async def process_frames(
         self,
         frames: torch.Tensor,
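Note: the two flags added above are only initialized to False; no code in this diff reads or updates them yet. Given the "already installed" comment, a plausible (purely hypothetical) use would be to guard against installing the TeaCache hook twice:

# Hypothetical sketch: neither this check nor the flag update
# appears in the commit itself.
if not self.text_to_video_teacache:
    apply_teacache(self.text_to_video)
    self.text_to_video_teacache = True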
@@ -451,7 +433,6 @@ class EndpointHandler:
 
         # Check if image-to-video generation is requested
         if support_image_prompt and input_image:
-            self._configure_teacache(self.image_to_video, config)
             processed_image = process_input_image(
                 input_image,
                 config.width,
@@ -463,7 +444,6 @@
             # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
             frames = self.image_to_video(**generation_kwargs).frames
         else:
-            self._configure_teacache(self.text_to_video, config)
             # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
             # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
             frames = self.text_to_video(**generation_kwargs).frames
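Note: with _configure_teacache removed from both call sites, TeaCache is now installed once at pipeline initialization instead of being reconfigured per request, so per-request settings such as config.enable_teacache and config.teacache_threshold no longer appear to take effect unless apply_teacache consults them elsewhere. This mirrors the surviving inline comment about enhance-a-video: the hook cannot be installed multiple times, so it is applied once up front rather than toggled dynamically.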