File size: 2,919 Bytes
5bdf5c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b843da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from diffusers import (ControlNetModel, 
                       StableDiffusionXLControlNetImg2ImgPipeline, 
                       AutoencoderKL,
                       T2IAdapter,
                       StableDiffusionXLAdapterPipeline,
                       EulerAncestralDiscreteScheduler)

from controlnet_aux.pidi import PidiNetDetector

from PIL import Image
import os


def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
    """Load an ``AutoencoderKL`` VAE with float16 weights.

    Args:
        model_name: Hub repo id or local path of the VAE
            (defaults to ``madebyollin/sdxl-vae-fp16-fix``).

    Returns:
        The loaded ``AutoencoderKL`` instance.
    """
    half_precision = torch.float16
    vae = AutoencoderKL.from_pretrained(model_name, torch_dtype=half_precision)
    return vae

def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
    """Load a ``ControlNetModel`` with float16 weights.

    Args:
        model_name: Hub repo id or local path of the ControlNet
            (defaults to ``diffusers/controlnet-canny-sdxl-1.0``).

    Returns:
        The loaded ``ControlNetModel`` instance.
    """
    half_precision = torch.float16
    controlnet = ControlNetModel.from_pretrained(model_name, torch_dtype=half_precision)
    return controlnet

def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
                adapter_type="full_adapter_xl"):
    """Load a ``T2IAdapter`` with float16 weights.

    Args:
        model_name: Hub repo id or local path containing the adapter.
        subfolder: Subfolder of ``model_name`` that holds the adapter weights.
        adapter_type: Adapter architecture forwarded to ``from_pretrained``;
            only ``"full_adapter_xl"`` is supported.

    Returns:
        The loaded ``T2IAdapter`` instance.

    Raises:
        ValueError: If ``adapter_type`` is not supported.
    """
    if adapter_type != "full_adapter_xl":
        # Fail fast: silently returning None would only surface later as a
        # confusing error when the adapter is actually used.
        raise ValueError(
            f"Unsupported adapter_type {adapter_type!r}; expected 'full_adapter_xl'"
        )
    return T2IAdapter.from_pretrained(model_name,
                                      subfolder=subfolder,
                                      torch_dtype=torch.float16,
                                      adapter_type=adapter_type)

def get_scheduler(model_name, scheduler_type="discrete"):
    """Load a scheduler from the ``scheduler`` subfolder of ``model_name``.

    Args:
        model_name: Hub repo id or local path of the base pipeline model.
        scheduler_type: Scheduler family; only ``"discrete"``
            (``EulerAncestralDiscreteScheduler``) is supported.

    Returns:
        The loaded ``EulerAncestralDiscreteScheduler``.

    Raises:
        ValueError: If ``scheduler_type`` is not supported.
    """
    if scheduler_type != "discrete":
        # Fail fast instead of returning None, which would otherwise be passed
        # along as if it were a real scheduler.
        raise ValueError(
            f"Unsupported scheduler_type {scheduler_type!r}; expected 'discrete'"
        )
    return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")


def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
    """Load an edge/sketch detector used to build conditioning images.

    Args:
        model_name: Hub repo id or local path of the annotator weights.
        model_type: Detector kind; only ``'pidi'`` (``PidiNetDetector``) is
            supported.

    Returns:
        The loaded ``PidiNetDetector``.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type != 'pidi':
        # Fail fast instead of returning None and failing at first use.
        raise ValueError(f"Unsupported model_type {model_type!r}; expected 'pidi'")
    return PidiNetDetector.from_pretrained(model_name)


def load_lora(pipe, lora_path=None):
    """Best-effort load of LoRA weights from a single local file into ``pipe``.

    Args:
        pipe: Object exposing ``load_lora_weights(directory, weight_name=...)``.
        lora_path: Path to the LoRA weights file. Relative paths are anchored
            to the current directory ("./..."); absolute paths are kept as-is.
            ``None`` skips loading.

    Any error raised while loading is printed and otherwise ignored, so the
    pipeline remains usable without the LoRA.
    """
    if lora_path is None:
        return
    try:
        lora_dir, lora_name = os.path.split(lora_path)
        # os.path.join(".", d) yields "./d" for relative d but returns d
        # unchanged when it is absolute; naive "./" + d string concatenation
        # would turn "/abs/dir" into the relative ".//abs/dir".
        lora_dir = os.path.join(".", lora_dir)
        pipe.load_lora_weights(lora_dir, weight_name=lora_name)
    except Exception as ex:  # deliberate best-effort: continue without the LoRA
        print(f"Failed to load LoRA weights from {lora_path!r}: {ex}")


def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
    """Build an SDXL pipeline conditioned on a ControlNet or a T2I-Adapter.

    The ControlNet takes precedence when both ``controlnet`` and ``adapter``
    are given.

    Args:
        vae: VAE component to plug into the pipeline (e.g. from ``get_vae``).
        model_name: Hub repo id or local path of the base SDXL model.
        controlnet: Optional ControlNet; selects
            ``StableDiffusionXLControlNetImg2ImgPipeline``.
        adapter: Optional T2I-Adapter; selects
            ``StableDiffusionXLAdapterPipeline`` (loaded with the ``fp16``
            weight variant).
        scheduler: Optional scheduler override. When ``None`` it is not
            forwarded, so ``from_pretrained`` loads the model's own scheduler.
        lora_path: Optional LoRA weights file, loaded via ``load_lora``.

    Returns:
        The constructed pipeline, with LoRA weights applied when requested.

    Raises:
        ValueError: If neither ``controlnet`` nor ``adapter`` is provided.
    """
    if controlnet is None and adapter is None:
        raise ValueError("get_pipe requires either a controlnet or an adapter")

    common_kwargs = {"vae": vae, "torch_dtype": torch.float16}
    # Passing scheduler=None explicitly to diffusers' from_pretrained would
    # leave the pipeline without a scheduler component, so only forward a
    # real override (for both pipeline kinds, consistently).
    if scheduler is not None:
        common_kwargs["scheduler"] = scheduler

    if controlnet is not None:
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            model_name, controlnet=controlnet, **common_kwargs)
    else:
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            model_name, adapter=adapter, variant="fp16", **common_kwargs)

    load_lora(pipe, lora_path)
    return pipe