File size: 1,343 Bytes
91fb4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Any, Dict

from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG
from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG


# Registry of training configurations: the outer key is the model name and the
# inner key is the training type ("lora" or "full-finetune"). Each value is a
# config object imported from the corresponding per-model module above.
# get_config_from_model_name() looks entries up here and lists these keys, in
# insertion order, in its error messages.
SUPPORTED_MODEL_CONFIGS = {
    "hunyuan_video": {
        "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
        "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    },
    "ltx_video": {
        "lora": LTX_VIDEO_T2V_LORA_CONFIG,
        "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    },
    "cogvideox": {
        "lora": COGVIDEOX_T2V_LORA_CONFIG,
        "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
    },
}


def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
    """Look up the registered configuration for a model and training type.

    Args:
        model_name: Key into ``SUPPORTED_MODEL_CONFIGS`` selecting the model.
        training_type: Key into that model's sub-table selecting the
            training type.

    Returns:
        The object stored under
        ``SUPPORTED_MODEL_CONFIGS[model_name][training_type]``.

    Raises:
        ValueError: If ``model_name`` is not a registered model, or if
            ``training_type`` is not registered for that model. The message
            lists the keys that are available.
    """
    if model_name in SUPPORTED_MODEL_CONFIGS:
        configs_for_model = SUPPORTED_MODEL_CONFIGS[model_name]
        if training_type in configs_for_model:
            return configs_for_model[training_type]

        # Known model, unknown training type.
        raise ValueError(
            f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
        )

    # Unknown model: reported before the training type is ever examined.
    raise ValueError(
        f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
    )