Spaces:
Sleeping
Sleeping
Gainward777
commited on
Commit
•
5bdf5c2
1
Parent(s):
9d6c0b1
Update sd/utils/utils.py
Browse files- sd/utils/utils.py +67 -77
sd/utils/utils.py
CHANGED
@@ -1,78 +1,68 @@
|
|
1 |
-
import torch
|
2 |
-
from diffusers import (ControlNetModel,
|
3 |
-
StableDiffusionXLControlNetImg2ImgPipeline,
|
4 |
-
AutoencoderKL,
|
5 |
-
T2IAdapter,
|
6 |
-
StableDiffusionXLAdapterPipeline,
|
7 |
-
EulerAncestralDiscreteScheduler)
|
8 |
-
|
9 |
-
from controlnet_aux.pidi import PidiNetDetector
|
10 |
-
|
11 |
-
from PIL import Image
|
12 |
-
import os
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
def
|
40 |
-
if
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
load_lora(pipe, lora_path)
|
68 |
-
return pipe
|
69 |
-
|
70 |
-
elif adapter != None:
|
71 |
-
pipe=StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
|
72 |
-
adapter=adapter,
|
73 |
-
vae=vae,
|
74 |
-
scheduler=scheduler,
|
75 |
-
torch_dtype=torch.float16,
|
76 |
-
variant="fp16")
|
77 |
-
load_lora(pipe, lora_path)
|
78 |
return pipe
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import (ControlNetModel,
|
3 |
+
StableDiffusionXLControlNetImg2ImgPipeline,
|
4 |
+
AutoencoderKL,
|
5 |
+
T2IAdapter,
|
6 |
+
StableDiffusionXLAdapterPipeline,
|
7 |
+
EulerAncestralDiscreteScheduler)
|
8 |
+
|
9 |
+
from controlnet_aux.pidi import PidiNetDetector
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
import os
|
13 |
+
|
14 |
+
|
15 |
+
def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
|
16 |
+
return AutoencoderKL.from_pretrained(model_name, torch_dtype=torch.float16)
|
17 |
+
|
18 |
+
def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
|
19 |
+
return ControlNetModel.from_pretrained(model_name, torch_dtype=torch.float16)
|
20 |
+
|
21 |
+
def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
|
22 |
+
adapter_type="full_adapter_xl"):
|
23 |
+
if adapter_type == "full_adapter_xl":
|
24 |
+
return T2IAdapter.from_pretrained(model_name,
|
25 |
+
subfolder=subfolder,
|
26 |
+
torch_dtype=torch.float16,
|
27 |
+
adapter_type=adapter_type)
|
28 |
+
|
29 |
+
def get_scheduler(model_name, scheduler_type="discrete"):
|
30 |
+
if scheduler_type == "discrete":
|
31 |
+
return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
|
32 |
+
|
33 |
+
|
34 |
+
def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
|
35 |
+
if model_type == 'pidi':
|
36 |
+
return PidiNetDetector.from_pretrained(model_name)
|
37 |
+
|
38 |
+
|
39 |
+
def load_lora(pipe, lora_path=None):
|
40 |
+
if lora_path != None:
|
41 |
+
try:
|
42 |
+
lora_dir='./'+'/'.join(lora_path.split("/")[:-1])
|
43 |
+
lora_name=lora_path.split("/")[-1]
|
44 |
+
pipe.load_lora_weights(lora_dir, weight_name=lora_name)
|
45 |
+
except Exception as ex:
|
46 |
+
print(ex)
|
47 |
+
#return pipe
|
48 |
+
|
49 |
+
|
50 |
+
def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
|
51 |
+
if controlnet!=None:
|
52 |
+
pipe=StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(model_name,
|
53 |
+
controlnet=controlnet,
|
54 |
+
vae=vae,
|
55 |
+
torch_dtype=torch.float16)
|
56 |
+
|
57 |
+
load_lora(pipe, lora_path)
|
58 |
+
return pipe
|
59 |
+
|
60 |
+
elif adapter != None:
|
61 |
+
pipe=StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
|
62 |
+
adapter=adapter,
|
63 |
+
vae=vae,
|
64 |
+
scheduler=scheduler,
|
65 |
+
torch_dtype=torch.float16,
|
66 |
+
variant="fp16")
|
67 |
+
load_lora(pipe, lora_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
return pipe
|