"""Dispatch cache application to a model-family-specific adapter submodule.

The model family is detected from the class name of the given transformer or
pipeline object. The matching sibling submodule of this package (for example
``.flux``) is imported on demand. The work is then delegated to the
``apply_cache_on_transformer`` / ``apply_cache_on_pipe`` function that the
adapter submodule defines.
"""

import importlib

from diffusers import DiffusionPipeline

# Class-name prefix -> adapter submodule name (relative to this package).
# Prefixes are checked in order and the first match wins; keep this single
# table as the one place to register a new model family.
_ADAPTER_NAME_BY_PREFIX = (
    ("Flux", "flux"),
    ("Mochi", "mochi"),
    ("CogVideoX", "cogvideox"),
    ("HunyuanVideo", "hunyuan_video"),
)


def _resolve_adapter_name(cls_name, kind):
    """Return the adapter submodule name for an object whose class is *cls_name*.

    Args:
        cls_name: The ``__class__.__name__`` of the object being adapted.
        kind: Object kind used only in the error message
            (``"transformer"`` or ``"pipeline"``).

    Returns:
        The adapter submodule name, e.g. ``"flux"``.

    Raises:
        ValueError: If *cls_name* starts with none of the known prefixes.
    """
    for prefix, adapter_name in _ADAPTER_NAME_BY_PREFIX:
        if cls_name.startswith(prefix):
            return adapter_name
    raise ValueError(f"Unknown {kind} class name: {cls_name}")


def _load_adapter_fn(adapter_name, fn_name):
    """Import sibling submodule ``.<adapter_name>`` and return its *fn_name* attribute."""
    # Imported lazily so that only the adapter actually needed is loaded.
    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    return getattr(adapter_module, fn_name)


def apply_cache_on_transformer(transformer, *args, **kwargs):
    """Delegate to the adapter's ``apply_cache_on_transformer`` for *transformer*.

    The adapter is chosen from the prefix of ``type(transformer).__name__``.

    Args:
        transformer: Model object whose class name identifies its family.
        *args: Forwarded unchanged to the adapter function.
        **kwargs: Forwarded unchanged to the adapter function.

    Returns:
        Whatever the adapter's ``apply_cache_on_transformer`` returns.

    Raises:
        ValueError: If the transformer's class name matches no known family.
    """
    adapter_name = _resolve_adapter_name(
        transformer.__class__.__name__, "transformer"
    )
    apply_fn = _load_adapter_fn(adapter_name, "apply_cache_on_transformer")
    return apply_fn(transformer, *args, **kwargs)


def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Delegate to the adapter's ``apply_cache_on_pipe`` for *pipe*.

    The adapter is chosen from the prefix of ``type(pipe).__name__``.

    Args:
        pipe: A ``diffusers.DiffusionPipeline`` instance.
        *args: Forwarded unchanged to the adapter function.
        **kwargs: Forwarded unchanged to the adapter function.

    Returns:
        Whatever the adapter's ``apply_cache_on_pipe`` returns.

    Raises:
        TypeError: If *pipe* is not a ``DiffusionPipeline``.
        ValueError: If the pipeline's class name matches no known family.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(pipe, DiffusionPipeline):
        raise TypeError(
            f"Expected a DiffusionPipeline, got {type(pipe).__name__}"
        )
    adapter_name = _resolve_adapter_name(pipe.__class__.__name__, "pipeline")
    apply_fn = _load_adapter_fn(adapter_name, "apply_cache_on_pipe")
    return apply_fn(pipe, *args, **kwargs)