File size: 4,236 Bytes
c92c0ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from .huggingface_models import load_huggingface_model
from .replicate_api_models import load_replicate_model
from .openai_api_models import load_openai_model
from .other_api_models import load_other_model
from .local_models import load_local_model


# Registry of supported model identifiers, grouped by task.
# Every identifier has the form "{source}_{name}_{type}", which is the
# format that ``load_pipeline`` below splits on "_".
# Order is significant only insofar as callers may rely on list position;
# entries are kept in their historical order.

# Text-to-image models (type suffix "text2image").
IMAGE_GENERATION_MODELS = [
    "replicate_SDXL_text2image",
    "replicate_SD-v3.0_text2image",
    "replicate_SD-v2.1_text2image",
    "replicate_SD-v1.5_text2image",
    "replicate_SDXL-Lightning_text2image",
    "replicate_Kandinsky-v2.0_text2image",
    "replicate_Kandinsky-v2.2_text2image",
    "replicate_Proteus-v0.2_text2image",
    "replicate_Playground-v2.0_text2image",
    "replicate_Playground-v2.5_text2image",
    "replicate_Dreamshaper-xl-turbo_text2image",
    "replicate_SDXL-Deepcache_text2image",
    "replicate_Openjourney-v4_text2image",
    "replicate_LCM-v1.5_text2image",
    "replicate_Realvisxl-v3.0_text2image",
    "replicate_Realvisxl-v2.0_text2image",
    "replicate_Pixart-Sigma_text2image",
    "replicate_SSD-1b_text2image",
    "replicate_Open-Dalle-v1.1_text2image",
    "replicate_Deepfloyd-IF_text2image",
    "huggingface_SD-turbo_text2image",
    "huggingface_SDXL-turbo_text2image",
    "huggingface_Stable-cascade_text2image",
    "openai_Dalle-2_text2image",
    "openai_Dalle-3_text2image",
    "other_Midjourney-v6.0_text2image",
    "other_Midjourney-v5.0_text2image",
    "replicate_FLUX.1-schnell_text2image",
    "replicate_FLUX.1-pro_text2image",
    "replicate_FLUX.1-dev_text2image",
    "other_Meissonic_text2image",
    "replicate_FLUX-1.1-pro_text2image",
    "replicate_SD-v3.5-large_text2image",
    "replicate_SD-v3.5-large-turbo_text2image",
]

# Text-to-video models (type suffix "text2video").
VIDEO_GENERATION_MODELS = [
    "replicate_Zeroscope-v2-xl_text2video",
    "replicate_Animate-Diff_text2video",
    "replicate_OpenSora_text2video",
    "replicate_LaVie_text2video",
    "replicate_VideoCrafter2_text2video",
    "replicate_Stable-Video-Diffusion_text2video",
    "other_Runway-Gen3_text2video",
    "other_Pika-beta_text2video",
    "other_Pika-v1.0_text2video",
    "other_Runway-Gen2_text2video",
    "other_Sora_text2video",
    "replicate_Cogvideox-5b_text2video",
    "other_KLing-v1.0_text2video",
]

# Models whose type suffix is "b2i" (presumably box-to-image layout
# generation -- NOTE(review): confirm against the loaders).
B2I_MODELS = [
    "local_MIGC_b2i",
    "huggingface_ReCo_b2i",
]


def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name.

    Args:
        model_name (str): The name of the model to load, of the form
            ``{source}_{name}_{type}`` (e.g. ``replicate_SDXL_text2image``).
            ``source`` selects which loader is used; ``name`` and ``type``
            are forwarded to that loader unchanged.

    Returns:
        Whatever the selected source-specific loader returns
        (``load_replicate_model``, ``load_huggingface_model``,
        ``load_openai_model``, ``load_other_model`` or ``load_local_model``).

    Raises:
        ValueError: If ``model_name`` does not consist of exactly three
            ``_``-separated parts, or if ``source`` is not one of
            ``replicate``, ``huggingface``, ``openai``, ``other``, ``local``.
    """
    parts = model_name.split("_")
    if len(parts) != 3:
        # Fail with a descriptive message rather than the opaque
        # "not enough / too many values to unpack" error from tuple unpacking.
        raise ValueError(
            f"Invalid model name {model_name!r}: expected the form "
            "'{source}_{name}_{type}'"
        )
    # Use distinct local names instead of rebinding the ``model_name`` parameter.
    model_source, name, model_type = parts

    if model_source == "replicate":
        return load_replicate_model(name, model_type)
    if model_source == "huggingface":
        return load_huggingface_model(name, model_type)
    if model_source == "openai":
        return load_openai_model(name, model_type)
    if model_source == "other":
        return load_other_model(name, model_type)
    if model_source == "local":
        return load_local_model(name, model_type)
    raise ValueError(f"Model source {model_source} not supported")