File size: 524 Bytes
5b2ab1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from .controlnet import StableDiffusionControlNet
from .controlnet_inpaint import StableDiffusionControlNetInpaint

# Registry of supported diffusion pipelines, keyed by the public name that
# ``create_diffusion_model`` accepts; each value is the class instantiated
# for that name.
DIFFUSION_MODELS = dict(
    controlnet=StableDiffusionControlNet,
    controlnet_inpaint=StableDiffusionControlNetInpaint,
)


class _UnknownDiffusionModelError(ValueError, AssertionError):
    """Raised when ``create_diffusion_model`` receives an unregistered name.

    Inherits from ``ValueError`` (the conventional type for a bad argument
    value) and from ``AssertionError`` (what this check raised when it was
    written as an ``assert``), so existing ``except`` clauses keep working.
    """


def create_diffusion_model(diffusion_model_name: str, **kwargs):
    """Instantiate the diffusion model registered under *diffusion_model_name*.

    Args:
        diffusion_model_name: Key into ``DIFFUSION_MODELS`` selecting which
            model class to build.
        **kwargs: Forwarded unchanged to the selected class's constructor.

    Returns:
        A new instance of the registered model class.

    Raises:
        ValueError: If *diffusion_model_name* is not registered. The raised
            exception is also an ``AssertionError`` for backward
            compatibility.
    """
    # An explicit raise (rather than ``assert``) keeps this check active
    # under ``python -O``, where asserts are stripped.
    try:
        model_cls = DIFFUSION_MODELS[diffusion_model_name]
    except KeyError:
        raise _UnknownDiffusionModelError(
            "Diffusion model name must be one of "
            + ", ".join(DIFFUSION_MODELS)
            + f"; got {diffusion_model_name!r}"
        ) from None

    # The constructor is called outside the ``try`` so that a KeyError raised
    # inside the model's own __init__ is not misreported as an unknown name.
    return model_cls(**kwargs)