radames committed on
Commit 6c0021c · Parent: 1ea3019

PCM: Phased Consistency Model ControlNet

Files changed (1)
  1. server/pipelines/controlnetPCMSD15.py +256 -0
server/pipelines/controlnetPCMSD15.py ADDED
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    TCDScheduler,
    AutoencoderTiny,
)
from compel import Compel
import torch
from pipelines.utils.canny_gpu import SobelOperator

try:
    # Imported for its side effect: IPEX patches PyTorch with Intel-optimized kernels.
    import intel_extension_for_pytorch as ipex  # type: ignore  # noqa: F401
except ImportError:
    pass

from typing import Optional

from config import Args
from pydantic import BaseModel, Field
from PIL import Image

taesd_model = "madebyollin/taesd"
controlnet_model = "lllyasviel/control_v11p_sd15_canny"
base_model_id = "runwayml/stable-diffusion-v1-5"
pcm_base = "wangfuyun/PCM_Weights"
# Each entry maps a UI label to [LoRA weight file, num_inference_steps, guidance_scale].
pcm_lora_ckpts = {
    "2-Step": ["pcm_sd15_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_sd15_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_sd15_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_sd15_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_sd15_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_sd15_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_sd15_normalcfg_16step_converted.safetensors", 16, 7.5],
}
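# The "small CFG" checkpoints are distilled to run without classifier-free
# guidance: they are paired with guidance_scale=0.0 above, and diffusers skips
# CFG entirely whenever guidance_scale <= 1, so each denoising step costs a
# single UNet forward pass. The "Normal CFG" variants keep a conventional
# guidance_scale of 7.5. Example lookup: pcm_lora_ckpts["4-Step"] selects
# pcm_sd15_smallcfg_4step_converted.safetensors, 4 steps, no CFG.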
default_prompt = "Portrait of The Terminator, glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
page_content = """

"""


class Pipeline:
    class Info(BaseModel):
        name: str = "controlnet+loras+sd15"
        title: str = "PCM + LoRA + ControlNet"
        description: str = "Generates an image from a text prompt and an input image"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        lora_ckpt_id: str = Field(
            "4-Step",
            title="PCM Base Model",
            values=list(pcm_lora_ckpts.keys()),
            field="select",
            id="lora_ckpt_id",
        )
        seed: int = Field(
            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
        )
        # Width/height are fixed by the app; the controls are disabled and hidden.
        width: int = Field(
            768, min=64, max=1024, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            768, min=64, max=1024, title="Height", disabled=True, hide=True, id="height"
        )
        strength: float = Field(
            0.5,
            min=0.25,
            max=1.0,
            step=0.001,
            title="Strength",
            field="range",
            hide=True,
            id="strength",
        )
        controlnet_scale: float = Field(
            0.8,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Scale",
            field="range",
            hide=True,
            id="controlnet_scale",
        )
        controlnet_start: float = Field(
            0.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Start",
            field="range",
            hide=True,
            id="controlnet_start",
        )
        controlnet_end: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet End",
            field="range",
            hide=True,
            id="controlnet_end",
        )
        canny_low_threshold: float = Field(
            0.31,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny Low Threshold",
            field="range",
            hide=True,
            id="canny_low_threshold",
        )
        canny_high_threshold: float = Field(
            0.125,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny High Threshold",
            field="range",
            hide=True,
            id="canny_high_threshold",
        )
        debug_canny: bool = Field(
            False,
            title="Debug Canny",
            field="checkbox",
            hide=True,
            id="debug_canny",
        )
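        # Note: field=, id=, hide=, disabled=, values=, min=, max= and step=
        # are not pydantic validation arguments; they pass through as extra
        # field metadata that the app's frontend presumably uses to build the
        # settings UI (an inference from this file, not stated in the commit).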

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        controlnet_canny = ControlNetModel.from_pretrained(
            controlnet_model, torch_dtype=torch_dtype
        ).to(device)

        if args.safety_checker:
            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                base_model_id,
                controlnet=controlnet_canny,
            )
        else:
            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                base_model_id,
                safety_checker=None,
                controlnet=controlnet_canny,
            )

        # GPU Sobel filter used to derive the canny-style control image.
        self.canny_torch = SobelOperator(device=device)

        self.pipe.scheduler = TCDScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            timestep_spacing="trailing",
        )
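        # TCDScheduler with "trailing" timestep spacing is the sampler the
        # PCM_Weights release pairs with these converted LoRA checkpoints.
        # Trailing spacing keeps the final (highest-noise) training timestep
        # in the schedule, which matters when only 2-16 steps are taken.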
164
+
165
+ self.pipe.set_progress_bar_config(disable=True)
166
+ if device.type != "mps":
167
+ self.pipe.unet.to(memory_format=torch.channels_last)
168
+
169
+ if args.taesd:
170
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
171
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
172
+ ).to(device)
173
+
174
+ self.loaded_lora = "4-Step"
175
+ self.pipe.load_lora_weights(
176
+ pcm_base,
177
+ weight_name=pcm_lora_ckpts[self.loaded_lora][0],
178
+ subfolder="sd15",
179
+ )
180
+ self.pipe.to(device=device, dtype=torch_dtype).to(device)
181
+ if args.compel:
182
+ self.compel_proc = Compel(
183
+ tokenizer=self.pipe.tokenizer,
184
+ text_encoder=self.pipe.text_encoder,
185
+ truncate_long_prompts=False,
186
+ )
187
+ if args.torch_compile:
188
+ self.pipe.unet = torch.compile(
189
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
190
+ )
191
+ self.pipe.vae = torch.compile(
192
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
193
+ )
194
+ self.pipe(
195
+ prompt="warmup",
196
+ image=[Image.new("RGB", (768, 768))],
197
+ control_image=[Image.new("RGB", (768, 768))],
198
+ )

    def predict(self, params: "Pipeline.InputParams") -> Optional[Image.Image]:
        generator = torch.manual_seed(params.seed)
        guidance_scale = pcm_lora_ckpts[params.lora_ckpt_id][2]
        steps = pcm_lora_ckpts[params.lora_ckpt_id][1]

        if self.loaded_lora != params.lora_ckpt_id:
            checkpoint = pcm_lora_ckpts[params.lora_ckpt_id][0]
            # Unload the previous adapter first; repeated load_lora_weights
            # calls would otherwise stack adapters instead of replacing them.
            self.pipe.unload_lora_weights()
            self.pipe.load_lora_weights(
                pcm_base,
                weight_name=checkpoint,
                subfolder="sd15",
            )
            self.loaded_lora = params.lora_ckpt_id

        prompt_embeds = None
        prompt = params.prompt
        if hasattr(self, "compel_proc"):
            # Compel turns weighted-prompt syntax into text embeddings.
            prompt_embeds = self.compel_proc(prompt)
            prompt = None
        control_image = self.canny_torch(
            params.image, params.canny_low_threshold, params.canny_high_threshold
        )
        strength = params.strength
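        # In diffusers img2img, strength both sets how much noise is added to
        # the input frame and scales the step count: the pipeline runs roughly
        # int(num_inference_steps * strength) denoising steps, so with the
        # 4-step checkpoint and strength 0.5 only about 2 steps execute.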
        results = self.pipe(
            image=params.image,
            control_image=control_image,
            prompt=prompt,
            prompt_embeds=prompt_embeds,
            generator=generator,
            strength=strength,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            width=params.width,
            height=params.height,
            output_type="pil",
            controlnet_conditioning_scale=params.controlnet_scale,
            control_guidance_start=params.controlnet_start,
            control_guidance_end=params.controlnet_end,
        )

        nsfw_content_detected = (
            results.nsfw_content_detected[0]
            if "nsfw_content_detected" in results
            else False
        )
        if nsfw_content_detected:
            return None
        result_image = results.images[0]
        if params.debug_canny:
            # Paste a 200x200 preview of the control image into the
            # bottom-right corner of the result for debugging.
            w0, h0 = (200, 200)
            control_image = control_image.resize((w0, h0))
            w1, h1 = result_image.size
            result_image.paste(control_image, (w1 - w0, h1 - h0))

        return result_image
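
For context, the core recipe this pipeline wires up (PCM LoRA weights on SD 1.5 sampled with TCDScheduler) can be sketched standalone, without the ControlNet and server plumbing. All model IDs and scheduler settings below come from the constants in the file; plain txt2img is used here for brevity:

    import torch
    from diffusers import StableDiffusionPipeline, TCDScheduler

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = TCDScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        timestep_spacing="trailing",
    )
    pipe.load_lora_weights(
        "wangfuyun/PCM_Weights",
        weight_name="pcm_sd15_smallcfg_4step_converted.safetensors",
        subfolder="sd15",
    )
    # 4 steps, no CFG, matching the "4-Step" entry in pcm_lora_ckpts.
    image = pipe(
        "a cinematic portrait", num_inference_steps=4, guidance_scale=0.0
    ).images[0]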