File size: 985 Bytes
6af7294 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from .diffusion_utils import build_pipeline
# Registry of supported model names. Each entry holds the four identifiers
# that get_model() hands to build_pipeline(): the base model, the UNet, the
# tokenizer and the text encoder (presumably Hugging Face Hub repo ids —
# confirm against build_pipeline).
# NOTE(review): the keys mix separator styles ("-" vs "_"); they are kept
# verbatim because callers may look them up by these exact strings.
NAME_TO_MODEL = {
    "stable-diffusion-v1-4": dict(
        model="CompVis/stable-diffusion-v1-4",
        unet="CompVis/stable-diffusion-v1-4",
        tokenizer="openai/clip-vit-large-patch14",
        text_encoder="openai/clip-vit-large-patch14",
    ),
    "stable_diffusion_v2_1": dict(
        model="stabilityai/stable-diffusion-2-1",
        unet="stabilityai/stable-diffusion-2-1",
        tokenizer="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        text_encoder="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    ),
}
def _normalize_model_name(name):
    """Return a comparison key for *name* that ignores case and -/_ differences."""
    return name.replace("-", "_").casefold()


def get_model(model_name):
    """Build the pipeline components for a model registered in NAME_TO_MODEL.

    An exact key match is tried first, so every name accepted previously
    resolves exactly as before. If that fails and ``model_name`` is a string,
    it is matched case-insensitively with ``-`` and ``_`` treated as
    equivalent. This is needed because the registry keys mix both separator
    styles; for example, ``"stable-diffusion-v2-1"`` resolves to the
    ``"stable_diffusion_v2_1"`` entry.

    Args:
        model_name: Key (or separator/case variant of a key) in NAME_TO_MODEL.

    Returns:
        The 4-tuple ``(vae, tokenizer, text_encoder, unet)`` unpacked from
        ``build_pipeline``'s result.

    Raises:
        ValueError: If no registry entry matches ``model_name``. The message
            lists the available model names.
    """
    model = NAME_TO_MODEL.get(model_name)
    if model is None and isinstance(model_name, str):
        # Fallback: tolerate "-" vs "_" and case differences in the name.
        wanted = _normalize_model_name(model_name)
        model = next(
            (
                config
                for key, config in NAME_TO_MODEL.items()
                if isinstance(key, str) and _normalize_model_name(key) == wanted
            ),
            None,
        )
    if model is None:
        raise ValueError(f"Model name {model_name} not found. Available models: {list(NAME_TO_MODEL.keys())}")
    # Argument order matters: build_pipeline takes (model, tokenizer,
    # text_encoder, unet), while the result is unpacked as
    # (vae, tokenizer, text_encoder, unet). Unpacking also enforces that
    # exactly four components come back.
    vae, tokenizer, text_encoder, unet = build_pipeline(
        model["model"], model["tokenizer"], model["text_encoder"], model["unet"]
    )
    return vae, tokenizer, text_encoder, unet