from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG | |
def load_model(args, in_channels, out_channels, factor_kwargs):
    """Build a HunyuanVideo diffusion transformer for the model named in ``args``.

    Args:
        args: Namespace-like object exposing a ``model`` attribute; its value
            must be a key of ``HUNYUAN_VIDEO_CONFIG``. The object itself is
            also passed positionally to the constructor.
            NOTE(review): presumably an ``argparse.Namespace`` — confirm with callers.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        factor_kwargs (dict): Extra keyword arguments forwarded verbatim to
            the ``HYVideoDiffusionTransformer`` constructor.

    Returns:
        HYVideoDiffusionTransformer: The constructed model.

    Raises:
        NotImplementedError: If ``args.model`` is not a known configuration
            name. The message lists the supported names.
        TypeError: If ``factor_kwargs`` repeats a key that is already present
            in the selected configuration, because Python rejects duplicate
            keyword arguments.
    """
    try:
        model_config = HUNYUAN_VIDEO_CONFIG[args.model]
    except KeyError:
        supported = ", ".join(map(str, HUNYUAN_VIDEO_CONFIG))
        # `from None`: the KeyError adds no information beyond this message.
        raise NotImplementedError(
            f"Unsupported model {args.model!r}; supported models: {supported}"
        ) from None

    # Configuration hyper-parameters and factor_kwargs are merged into one call.
    # Overlapping keys raise TypeError rather than silently overriding.
    return HYVideoDiffusionTransformer(
        args,
        in_channels=in_channels,
        out_channels=out_channels,
        **model_config,
        **factor_kwargs,
    )