jbilcke-hf's picture
jbilcke-hf HF Staff
initial commit log 🪵🦫
91fb4ef
raw
history blame
1.34 kB
from typing import Any, Dict
from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG
from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG
# Registry of training configs, keyed first by model name and then by
# training type ("lora" or "full-finetune"). Each leaf is the config constant
# imported from the matching sibling module (.hunyuan_video, .ltx_video,
# .cogvideox). get_config_from_model_name() reads this table, and its error
# messages list these keys in the order they are declared here.
SUPPORTED_MODEL_CONFIGS: Dict[str, Dict[str, Dict[str, Any]]] = {
    "hunyuan_video": {
        "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
        "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    },
    "ltx_video": {
        "lora": LTX_VIDEO_T2V_LORA_CONFIG,
        "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    },
    "cogvideox": {
        "lora": COGVIDEOX_T2V_LORA_CONFIG,
        "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
    },
}
def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
    """Return the registered config for a model name and training type.

    Args:
        model_name: Top-level key of ``SUPPORTED_MODEL_CONFIGS``.
        training_type: Key inside the selected model's entry.

    Returns:
        The config object stored in ``SUPPORTED_MODEL_CONFIGS`` for the pair.
        The registry entry itself is returned, not a copy.

    Raises:
        ValueError: If ``model_name`` is not a key of the registry, or if
            ``training_type`` is not a key of that model's entry. The message
            lists the accepted keys.
    """
    # Resolve the per-model table first so the second check and the final
    # lookup both work on the same mapping.
    if model_name in SUPPORTED_MODEL_CONFIGS:
        per_model_configs = SUPPORTED_MODEL_CONFIGS[model_name]
    else:
        known_models = list(SUPPORTED_MODEL_CONFIGS)
        raise ValueError(f"Model {model_name} not supported. Supported models are: {known_models}")

    if training_type in per_model_configs:
        return per_model_configs[training_type]

    known_training_types = list(per_model_configs)
    raise ValueError(
        f"Training type {training_type} not supported for model {model_name}. Supported training types are: {known_training_types}"
    )