File size: 1,695 Bytes
9d3fd05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import importlib
from diffusers import DiffusionPipeline
def apply_cache_on_transformer(transformer, *args, **kwargs):
    """Dispatch cache installation to the adapter matching *transformer*'s class.

    The adapter is selected by the prefix of ``transformer.__class__.__name__``
    (case-sensitive, first match wins). It is imported lazily as a module
    relative to this module's package (e.g. ``.flux``), and that module's
    ``apply_cache_on_transformer`` is called with the transformer plus every
    extra positional and keyword argument.

    Returns:
        Whatever the selected adapter's ``apply_cache_on_transformer`` returns.

    Raises:
        ValueError: if the class name starts with none of the known prefixes.
    """
    cls_name = transformer.__class__.__name__
    # Ordered (class-name prefix, adapter module name) pairs; scanned in order.
    prefix_to_adapter = (
        ("Flux", "flux"),
        ("Mochi", "mochi"),
        ("CogVideoX", "cogvideox"),
        ("HunyuanVideo", "hunyuan_video"),
    )
    adapter_name = next(
        (adapter for prefix, adapter in prefix_to_adapter if cls_name.startswith(prefix)),
        None,
    )
    if adapter_name is None:
        raise ValueError(f"Unknown transformer class name: {cls_name}")
    # Import only the adapter actually needed, relative to this package.
    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    return adapter_module.apply_cache_on_transformer(transformer, *args, **kwargs)
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Dispatch cache installation to the adapter matching *pipe*'s class.

    The adapter is selected by the prefix of ``pipe.__class__.__name__``
    (case-sensitive, first match wins). It is imported lazily as a module
    relative to this module's package (e.g. ``.flux``), and that module's
    ``apply_cache_on_pipe`` is called with the pipeline plus every extra
    positional and keyword argument.

    Returns:
        Whatever the selected adapter's ``apply_cache_on_pipe`` returns.

    Raises:
        TypeError: if *pipe* is not a ``DiffusionPipeline`` instance.
        ValueError: if the pipeline class name starts with none of the known
            prefixes.
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let arbitrary objects reach the adapters.
    if not isinstance(pipe, DiffusionPipeline):
        raise TypeError(
            f"Expected a DiffusionPipeline instance, got {type(pipe).__name__}"
        )
    pipe_cls_name = pipe.__class__.__name__
    # Ordered (class-name prefix, adapter module name) pairs; first match wins.
    adapters_by_prefix = (
        ("Flux", "flux"),
        ("Mochi", "mochi"),
        ("CogVideoX", "cogvideox"),
        ("HunyuanVideo", "hunyuan_video"),
    )
    for prefix, adapter_name in adapters_by_prefix:
        if pipe_cls_name.startswith(prefix):
            break
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
    # Import only the adapter actually needed, relative to this package.
    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
|