|
import importlib |
|
|
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
def apply_cache_on_transformer(transformer, *args, **kwargs):
    """Apply caching to ``transformer`` through the adapter matching its class.

    The adapter is a sibling submodule of this package. It is selected by the
    prefix of ``transformer.__class__.__name__`` and imported only once it has
    been selected.

    Args:
        transformer: Object whose class name starts with one of the known
            prefixes (``Flux``, ``Mochi``, ``CogVideoX``, ``HunyuanVideo``).
        *args: Forwarded unchanged to the adapter's
            ``apply_cache_on_transformer``.
        **kwargs: Forwarded unchanged to the adapter's
            ``apply_cache_on_transformer``.

    Returns:
        Whatever the selected adapter's ``apply_cache_on_transformer`` returns.

    Raises:
        ValueError: If the class name starts with none of the known prefixes.
    """
    # Ordered (class-name prefix, adapter submodule) pairs; first match wins.
    prefix_to_adapter = (
        ("Flux", "flux"),
        ("Mochi", "mochi"),
        ("CogVideoX", "cogvideox"),
        ("HunyuanVideo", "hunyuan_video"),
    )

    transformer_cls_name = transformer.__class__.__name__
    for prefix, candidate in prefix_to_adapter:
        if transformer_cls_name.startswith(prefix):
            adapter_name = candidate
            break
    else:
        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")

    # Relative import: only the adapter that was actually selected is loaded.
    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    return adapter_module.apply_cache_on_transformer(transformer, *args, **kwargs)
|
|
|
|
|
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Apply caching to ``pipe`` through the adapter matching its class.

    The adapter is a sibling submodule of this package. It is selected by the
    prefix of ``pipe.__class__.__name__`` and imported only once it has been
    selected.

    Args:
        pipe: A ``DiffusionPipeline`` instance whose class name starts with
            one of the known prefixes (``Flux``, ``Mochi``, ``CogVideoX``,
            ``HunyuanVideo``).
        *args: Forwarded unchanged to the adapter's ``apply_cache_on_pipe``.
        **kwargs: Forwarded unchanged to the adapter's ``apply_cache_on_pipe``.

    Returns:
        Whatever the selected adapter's ``apply_cache_on_pipe`` returns.

    Raises:
        TypeError: If ``pipe`` is not a ``DiffusionPipeline`` instance.
        ValueError: If the class name starts with none of the known prefixes.
    """
    # Explicit check instead of ``assert`` so the validation is still enforced
    # when Python runs with -O (which strips assert statements).
    if not isinstance(pipe, DiffusionPipeline):
        raise TypeError(
            f"pipe must be a DiffusionPipeline instance, got {type(pipe).__name__}"
        )

    # Ordered (class-name prefix, adapter submodule) pairs; first match wins.
    prefix_to_adapter = (
        ("Flux", "flux"),
        ("Mochi", "mochi"),
        ("CogVideoX", "cogvideox"),
        ("HunyuanVideo", "hunyuan_video"),
    )

    pipe_cls_name = pipe.__class__.__name__
    for prefix, candidate in prefix_to_adapter:
        if pipe_cls_name.startswith(prefix):
            adapter_name = candidate
            break
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

    # Relative import: only the adapter that was actually selected is loaded.
    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    return adapter_module.apply_cache_on_pipe(pipe, *args, **kwargs)
|
|