RobertML's picture
Add files using upload-large-folder tool
33e938e verified
raw
history blame
1.7 kB
import importlib
from diffusers import DiffusionPipeline
def apply_cache_on_transformer(transformer, *args, **kwargs):
transformer_cls_name = transformer.__class__.__name__
if False:
pass
elif transformer_cls_name.startswith("Flux"):
adapter_name = "flux"
elif transformer_cls_name.startswith("Mochi"):
adapter_name = "mochi"
elif transformer_cls_name.startswith("CogVideoX"):
adapter_name = "cogvideox"
elif transformer_cls_name.startswith("HunyuanVideo"):
adapter_name = "hunyuan_video"
else:
raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")
adapter_module = importlib.import_module(f".{adapter_name}", __package__)
apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
assert isinstance(pipe, DiffusionPipeline)
pipe_cls_name = pipe.__class__.__name__
if False:
pass
elif pipe_cls_name.startswith("Flux"):
adapter_name = "flux"
elif pipe_cls_name.startswith("Mochi"):
adapter_name = "mochi"
elif pipe_cls_name.startswith("CogVideoX"):
adapter_name = "cogvideox"
elif pipe_cls_name.startswith("HunyuanVideo"):
adapter_name = "hunyuan_video"
else:
raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
adapter_module = importlib.import_module(f".{adapter_name}", __package__)
apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
return apply_cache_on_pipe_fn(pipe, *args, **kwargs)