Gainward777 commited on
Commit
5bdf5c2
1 Parent(s): 9d6c0b1

Update sd/utils/utils.py

Browse files
Files changed (1) hide show
  1. sd/utils/utils.py +67 -77
sd/utils/utils.py CHANGED
@@ -1,78 +1,68 @@
1
- import torch
2
- from diffusers import (ControlNetModel,
3
- StableDiffusionXLControlNetImg2ImgPipeline,
4
- AutoencoderKL,
5
- T2IAdapter,
6
- StableDiffusionXLAdapterPipeline,
7
- EulerAncestralDiscreteScheduler)
8
-
9
- from controlnet_aux.pidi import PidiNetDetector
10
-
11
- from PIL import Image
12
- import os
13
-
14
-
15
- #VAE=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
16
-
17
- #CONTROLNET = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
18
-
19
- #ADAPTER = T2IAdapter.from_pretrained("Adapter/t2iadapter",
20
- #subfolder="sketch_sdxl_1.0",
21
- #torch_dtype=torch.float16,
22
- #adapter_type="full_adapter_xl")
23
-
24
-
25
- def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
26
- return AutoencoderKL.from_pretrained(model_name, torch_dtype=torch.float16)
27
-
28
- def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
29
- return ControlNetModel.from_pretrained(model_name, torch_dtype=torch.float16)
30
-
31
- def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
32
- adapter_type="full_adapter_xl"):
33
- if adapter_type == "full_adapter_xl":
34
- return T2IAdapter.from_pretrained(model_name,
35
- subfolder=subfolder,
36
- torch_dtype=torch.float16,
37
- adapter_type=adapter_type)
38
-
39
- def get_scheduler(model_name, scheduler_type="discrete"):
40
- if scheduler_type == "discrete":
41
- return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
42
-
43
-
44
- def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
45
- if model_type == 'pidi':
46
- return PidiNetDetector.from_pretrained(model_name)
47
-
48
-
49
- def load_lora(pipe, lora_path=None):
50
- if lora_path != None:
51
- try:
52
- lora_dir='./'+'/'.join(lora_path.split("/")[:-1])
53
- lora_name=lora_path.split("/")[-1]
54
- pipe.load_lora_weights(lora_dir, weight_name=lora_name)
55
- except Exception as ex:
56
- print(ex)
57
- #return pipe
58
-
59
-
60
- def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
61
- if controlnet!=None:
62
- pipe=StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(model_name,
63
- controlnet=controlnet,
64
- vae=vae,
65
- torch_dtype=torch.float16)
66
-
67
- load_lora(pipe, lora_path)
68
- return pipe
69
-
70
- elif adapter != None:
71
- pipe=StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
72
- adapter=adapter,
73
- vae=vae,
74
- scheduler=scheduler,
75
- torch_dtype=torch.float16,
76
- variant="fp16")
77
- load_lora(pipe, lora_path)
78
  return pipe
 
1
+ import torch
2
+ from diffusers import (ControlNetModel,
3
+ StableDiffusionXLControlNetImg2ImgPipeline,
4
+ AutoencoderKL,
5
+ T2IAdapter,
6
+ StableDiffusionXLAdapterPipeline,
7
+ EulerAncestralDiscreteScheduler)
8
+
9
+ from controlnet_aux.pidi import PidiNetDetector
10
+
11
+ from PIL import Image
12
+ import os
13
+
14
+
15
+ def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
16
+ return AutoencoderKL.from_pretrained(model_name, torch_dtype=torch.float16)
17
+
18
+ def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
19
+ return ControlNetModel.from_pretrained(model_name, torch_dtype=torch.float16)
20
+
21
+ def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
22
+ adapter_type="full_adapter_xl"):
23
+ if adapter_type == "full_adapter_xl":
24
+ return T2IAdapter.from_pretrained(model_name,
25
+ subfolder=subfolder,
26
+ torch_dtype=torch.float16,
27
+ adapter_type=adapter_type)
28
+
29
+ def get_scheduler(model_name, scheduler_type="discrete"):
30
+ if scheduler_type == "discrete":
31
+ return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
32
+
33
+
34
+ def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
35
+ if model_type == 'pidi':
36
+ return PidiNetDetector.from_pretrained(model_name)
37
+
38
+
39
+ def load_lora(pipe, lora_path=None):
40
+ if lora_path != None:
41
+ try:
42
+ lora_dir='./'+'/'.join(lora_path.split("/")[:-1])
43
+ lora_name=lora_path.split("/")[-1]
44
+ pipe.load_lora_weights(lora_dir, weight_name=lora_name)
45
+ except Exception as ex:
46
+ print(ex)
47
+ #return pipe
48
+
49
+
50
+ def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
51
+ if controlnet!=None:
52
+ pipe=StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(model_name,
53
+ controlnet=controlnet,
54
+ vae=vae,
55
+ torch_dtype=torch.float16)
56
+
57
+ load_lora(pipe, lora_path)
58
+ return pipe
59
+
60
+ elif adapter != None:
61
+ pipe=StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
62
+ adapter=adapter,
63
+ vae=vae,
64
+ scheduler=scheduler,
65
+ torch_dtype=torch.float16,
66
+ variant="fp16")
67
+ load_lora(pipe, lora_path)
 
 
 
 
 
 
 
 
 
 
68
  return pipe