radames committed
Commit e13bdf0 · 1 Parent(s): ea6542c

Flash Diffusion JasperAI

server/pipelines/controlnetDepthFlashSD.py ADDED
@@ -0,0 +1,277 @@
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    AutoencoderTiny,
    LCMScheduler,
)
from compel import Compel
import torch
from transformers import pipeline

try:
    import intel_extension_for_pytorch as ipex  # type: ignore
except ImportError:
    pass

from config import Args
from pydantic import BaseModel, Field
from PIL import Image
import math

controlnet_model = "lllyasviel/control_v11f1p_sd15_depth"
model_id = "runwayml/stable-diffusion-v1-5"
taesd_model = "madebyollin/taesd"

default_prompt = "Portrait of The Terminator with glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
page_content = """
<h1 class="text-3xl font-bold">Flash-SD + Depth</h1>
<h3 class="text-xl font-bold">Image-to-Image ControlNet</h3>
"""


class Pipeline:
    class Info(BaseModel):
        name: str = "controlnet+FlashSD+depth"
        title: str = "Flash-SD + Depth Controlnet"
        description: str = "Generates an image from an input image and a text prompt"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        negative_prompt: str = Field(
            default_negative_prompt,
            title="Negative Prompt",
            field="textarea",
            id="negative_prompt",
            hide=True,
        )
        seed: int = Field(
            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
        )
        steps: int = Field(
            2, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
        )
        width: int = Field(
            512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
        )
        guidance_scale: float = Field(
            0.0,
            min=0,
            max=10,
            step=0.001,
            title="Guidance Scale",
            field="range",
            hide=True,
            id="guidance_scale",
        )
        strength: float = Field(
            0.5,
            min=0.25,
            max=1.0,
            step=0.001,
            title="Strength",
            field="range",
            hide=True,
            id="strength",
        )
        eta: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Eta",
            field="range",
            hide=True,
            id="eta",
        )
        controlnet_scale: float = Field(
            0.5,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Scale",
            field="range",
            hide=True,
            id="controlnet_scale",
        )
        controlnet_start: float = Field(
            0.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Start",
            field="range",
            hide=True,
            id="controlnet_start",
        )
        controlnet_end: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet End",
            field="range",
            hide=True,
            id="controlnet_end",
        )
        debug_depth: bool = Field(
            False,
            title="Debug Depth",
            field="checkbox",
            hide=True,
            id="debug_depth",
        )

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        controlnet_depth = ControlNetModel.from_pretrained(
            controlnet_model, torch_dtype=torch_dtype
        )

        self.depth_estimator = pipeline(
            task="depth-estimation",
            # model="Intel/dpt-swinv2-large-384",
            # model="Intel/dpt-swinv2-base-384",
            model="LiheYoung/depth-anything-small-hf",
            device=device,
        )

        if args.safety_checker:
            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                model_id, controlnet=controlnet_depth, torch_dtype=torch_dtype
            )
        else:
            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                model_id,
                safety_checker=None,
                controlnet=controlnet_depth,
                torch_dtype=torch_dtype,
            )

        if args.taesd:
            self.pipe.vae = AutoencoderTiny.from_pretrained(
                taesd_model, torch_dtype=torch_dtype
            )

        # Flash Diffusion distillation LoRA from Jasper AI
        self.pipe.load_lora_weights("jasperai/flash-sd")
        self.pipe.fuse_lora()

        self.pipe.scheduler = LCMScheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            timestep_spacing="trailing",
        )

        if args.sfast:
            from sfast.compilers.stable_diffusion_pipeline_compiler import (
                compile,
                CompilationConfig,
            )

            config = CompilationConfig.Default()
            # config.enable_xformers = True
            config.enable_triton = True
            config.enable_cuda_graph = True
            self.pipe = compile(self.pipe, config=config)

        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(device=device)
        if device.type != "mps":
            self.pipe.unet.to(memory_format=torch.channels_last)

        if args.compel:
            # SD 1.5 has a single text encoder, so plain Compel without
            # pooled embeddings is enough here.
            self.pipe.compel_proc = Compel(
                tokenizer=self.pipe.tokenizer,
                text_encoder=self.pipe.text_encoder,
            )

        if args.torch_compile:
            self.pipe.unet = torch.compile(
                self.pipe.unet, mode="reduce-overhead", fullgraph=True
            )
            self.pipe.vae = torch.compile(
                self.pipe.vae, mode="reduce-overhead", fullgraph=True
            )
            self.pipe(
                prompt="warmup",
                image=[Image.new("RGB", (768, 768))],
                control_image=[Image.new("RGB", (768, 768))],
            )

    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        generator = torch.manual_seed(params.seed)

        prompt = params.prompt
        negative_prompt = params.negative_prompt
        prompt_embeds = None
        negative_prompt_embeds = None
        if hasattr(self.pipe, "compel_proc"):
            # Encode prompt and negative prompt in one batch, then split.
            _prompt_embeds = self.pipe.compel_proc(
                [params.prompt, params.negative_prompt]
            )
            prompt = None
            negative_prompt = None
            prompt_embeds = _prompt_embeds[0:1]
            negative_prompt_embeds = _prompt_embeds[1:2]

        control_image = self.depth_estimator(params.image)["depth"]
        steps = params.steps
        strength = params.strength
        if int(steps * strength) < 1:
            steps = math.ceil(1 / max(0.10, strength))

        results = self.pipe(
            image=params.image,
            control_image=control_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            generator=generator,
            strength=strength,
            eta=params.eta,
            num_inference_steps=steps,
            guidance_scale=params.guidance_scale,
            width=params.width,
            height=params.height,
            output_type="pil",
            controlnet_conditioning_scale=params.controlnet_scale,
            control_guidance_start=params.controlnet_start,
            control_guidance_end=params.controlnet_end,
        )

        nsfw_content_detected = (
            results.nsfw_content_detected[0]
            if "nsfw_content_detected" in results
            else False
        )
        if nsfw_content_detected:
            return None
        result_image = results.images[0]
        if params.debug_depth:
            # Paste a small copy of the depth map into the bottom-right corner.
            w0, h0 = (200, 200)
            control_image = control_image.resize((w0, h0))
            w1, h1 = result_image.size
            result_image.paste(control_image, (w1 - w0, h1 - h0))

        return result_image
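
For reference, a minimal driver sketch for exercising this class outside the web app. Assumptions: it is run from the server/ directory so config and pipelines resolve; the Args stand-in carries only the flags __init__ reads; input.jpg and output.png are illustrative paths. The real app populates InputParams from the browser UI instead.

import torch
from types import SimpleNamespace
from PIL import Image

from pipelines.controlnetDepthFlashSD import Pipeline

# Stand-in for config.Args; only the flags read in __init__ above are set.
args = SimpleNamespace(
    safety_checker=False, taesd=False, sfast=False, compel=False, torch_compile=False
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = Pipeline(args, device, torch_dtype)

# predict() only reads attributes, so a namespace mirroring InputParams works;
# the app attaches the input frame as `image` the same way.
params = SimpleNamespace(
    **Pipeline.InputParams().dict(),  # .model_dump() on pydantic v2
    image=Image.open("input.jpg").convert("RGB").resize((512, 512)),
)
result = pipe.predict(params)  # returns None when the safety checker trips
if result is not None:
    result.save("output.png")
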
server/pipelines/controlnetFlashSD.py ADDED
@@ -0,0 +1,270 @@
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    AutoencoderTiny,
    LCMScheduler,
)
from compel import Compel
import torch
from pipelines.utils.canny_gpu import SobelOperator

try:
    import intel_extension_for_pytorch as ipex  # type: ignore
except ImportError:
    pass

from config import Args
from pydantic import BaseModel, Field
from PIL import Image
import math

controlnet_model = "lllyasviel/control_v11p_sd15_canny"
model_id = "runwayml/stable-diffusion-v1-5"
taesd_model = "madebyollin/taesd"

default_prompt = "Portrait of The Terminator with glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
page_content = """
<h1 class="text-3xl font-bold">Flash-SD</h1>
<h3 class="text-xl font-bold">Image-to-Image ControlNet</h3>
"""


class Pipeline:
    class Info(BaseModel):
        name: str = "controlnet+FlashSD"
        title: str = "Flash-SD + Controlnet"
        description: str = "Generates an image from an input image and a text prompt"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        negative_prompt: str = Field(
            default_negative_prompt,
            title="Negative Prompt",
            field="textarea",
            id="negative_prompt",
            hide=True,
        )
        seed: int = Field(
            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
        )
        steps: int = Field(
            2, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
        )
        width: int = Field(
            512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
        )
        strength: float = Field(
            0.5,
            min=0.25,
            max=1.0,
            step=0.001,
            title="Strength",
            field="range",
            hide=True,
            id="strength",
        )
        controlnet_scale: float = Field(
            0.5,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Scale",
            field="range",
            hide=True,
            id="controlnet_scale",
        )
        controlnet_start: float = Field(
            0.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Start",
            field="range",
            hide=True,
            id="controlnet_start",
        )
        controlnet_end: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet End",
            field="range",
            hide=True,
            id="controlnet_end",
        )
        canny_low_threshold: float = Field(
            0.31,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny Low Threshold",
            field="range",
            hide=True,
            id="canny_low_threshold",
        )
        canny_high_threshold: float = Field(
            0.125,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny High Threshold",
            field="range",
            hide=True,
            id="canny_high_threshold",
        )
        debug_canny: bool = Field(
            False,
            title="Debug Canny",
            field="checkbox",
            hide=True,
            id="debug_canny",
        )

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        controlnet_canny = ControlNetModel.from_pretrained(
            controlnet_model, torch_dtype=torch_dtype
        )

        if args.safety_checker:
            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                model_id, controlnet=controlnet_canny, torch_dtype=torch_dtype
            )
        else:
            self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                model_id,
                safety_checker=None,
                controlnet=controlnet_canny,
                torch_dtype=torch_dtype,
            )

        self.pipe.scheduler = LCMScheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            timestep_spacing="trailing",
        )

        if args.taesd:
            self.pipe.vae = AutoencoderTiny.from_pretrained(
                taesd_model, torch_dtype=torch_dtype, use_safetensors=True
            )

        # Flash Diffusion distillation LoRA from Jasper AI
        self.pipe.load_lora_weights("jasperai/flash-sd")
        self.pipe.fuse_lora()

        self.canny_torch = SobelOperator(device=device)

        if args.sfast:
            from sfast.compilers.stable_diffusion_pipeline_compiler import (
                compile,
                CompilationConfig,
            )

            config = CompilationConfig.Default()
            # config.enable_xformers = True
            config.enable_triton = True
            config.enable_cuda_graph = True
            self.pipe = compile(self.pipe, config=config)

        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(device=device)
        if device.type != "mps":
            self.pipe.unet.to(memory_format=torch.channels_last)

        if args.compel:
            # SD 1.5 has a single text encoder, so plain Compel without
            # pooled embeddings is enough here.
            self.pipe.compel_proc = Compel(
                tokenizer=self.pipe.tokenizer,
                text_encoder=self.pipe.text_encoder,
            )

        if args.torch_compile:
            self.pipe.unet = torch.compile(
                self.pipe.unet, mode="reduce-overhead", fullgraph=True
            )
            self.pipe.vae = torch.compile(
                self.pipe.vae, mode="reduce-overhead", fullgraph=True
            )
            self.pipe(
                prompt="warmup",
                image=[Image.new("RGB", (768, 768))],
                control_image=[Image.new("RGB", (768, 768))],
            )

    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        generator = torch.manual_seed(params.seed)

        prompt = params.prompt
        negative_prompt = params.negative_prompt
        prompt_embeds = None
        negative_prompt_embeds = None
        if hasattr(self.pipe, "compel_proc"):
            # Encode prompt and negative prompt in one batch, then split.
            _prompt_embeds = self.pipe.compel_proc(
                [params.prompt, params.negative_prompt]
            )
            prompt = None
            negative_prompt = None
            prompt_embeds = _prompt_embeds[0:1]
            negative_prompt_embeds = _prompt_embeds[1:2]

        control_image = self.canny_torch(
            params.image, params.canny_low_threshold, params.canny_high_threshold
        )
        steps = params.steps
        strength = params.strength
        if int(steps * strength) < 1:
            steps = math.ceil(1 / max(0.10, strength))

        results = self.pipe(
            image=params.image,
            control_image=control_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            generator=generator,
            strength=strength,
            num_inference_steps=steps,
            guidance_scale=0,
            width=params.width,
            height=params.height,
            output_type="pil",
            controlnet_conditioning_scale=params.controlnet_scale,
            control_guidance_start=params.controlnet_start,
            control_guidance_end=params.controlnet_end,
        )

        nsfw_content_detected = (
            results.nsfw_content_detected[0]
            if "nsfw_content_detected" in results
            else False
        )
        if nsfw_content_detected:
            return None
        result_image = results.images[0]
        if params.debug_canny:
            # Paste a small copy of the canny map into the bottom-right corner.
            w0, h0 = (200, 200)
            control_image = control_image.resize((w0, h0))
            w1, h1 = result_image.size
            result_image.paste(control_image, (w1 - w0, h1 - h0))

        return result_image
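
This pipeline builds its control image with SobelOperator from pipelines/utils/canny_gpu.py, which is not part of this commit. Below is a minimal sketch of what such an operator could look like, assuming only the interface inferred from the call sites above (a PIL image plus low/high thresholds in, a PIL edge map out); the repository's actual implementation may differ.

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image

class SobelOperatorSketch:
    """Hypothetical stand-in for pipelines.utils.canny_gpu.SobelOperator."""

    def __init__(self, device: torch.device):
        kx = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
        # Stack horizontal and vertical Sobel kernels: shape (2, 1, 3, 3).
        self.kernel = torch.stack([kx, kx.t()]).unsqueeze(1).to(device)
        self.device = device

    @torch.no_grad()
    def __call__(
        self, image: Image.Image, low_threshold: float, high_threshold: float
    ) -> Image.Image:
        # Assumes low_threshold <= high_threshold.
        gray = TF.to_tensor(image.convert("L")).unsqueeze(0).to(self.device)  # (1,1,H,W)
        grad = F.conv2d(gray, self.kernel, padding=1)  # (1,2,H,W)
        mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()
        mag = mag / mag.max().clamp(min=1e-8)  # normalize to [0, 1]
        strong = (mag >= high_threshold).float()
        weak = ((mag >= low_threshold) & (mag < high_threshold)).float() * mag
        edges = (strong + weak).clamp(max=1.0)
        return TF.to_pil_image(edges.squeeze(0).cpu())
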
server/pipelines/controlnetFlashSDXL.py ADDED
@@ -0,0 +1,271 @@
from diffusers import (
    StableDiffusionXLControlNetImg2ImgPipeline,
    ControlNetModel,
    AutoencoderKL,
    LCMScheduler,
)
from compel import Compel, ReturnedEmbeddingsType
import torch
from pipelines.utils.canny_gpu import SobelOperator

try:
    import intel_extension_for_pytorch as ipex  # type: ignore
except ImportError:
    pass

from config import Args
from pydantic import BaseModel, Field
from PIL import Image
import math

# controlnet_model = "diffusers/controlnet-canny-sdxl-1.0"
controlnet_model = "xinsir/controlnet-canny-sdxl-1.0"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
taesd_model = "madebyollin/taesdxl"

default_prompt = "Portrait of The Terminator with glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
page_content = """
<h1 class="text-3xl font-bold">Flash-SDXL</h1>
<h3 class="text-xl font-bold">Image-to-Image ControlNet</h3>
"""


class Pipeline:
    class Info(BaseModel):
        name: str = "controlnet+FlashSDXL"
        title: str = "Flash-SDXL + Controlnet"
        description: str = "Generates an image from an input image and a text prompt"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        negative_prompt: str = Field(
            default_negative_prompt,
            title="Negative Prompt",
            field="textarea",
            id="negative_prompt",
            hide=True,
        )
        seed: int = Field(
            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
        )
        steps: int = Field(
            2, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
        )
        width: int = Field(
            1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
        )
        strength: float = Field(
            0.5,
            min=0.25,
            max=1.0,
            step=0.001,
            title="Strength",
            field="range",
            hide=True,
            id="strength",
        )
        controlnet_scale: float = Field(
            0.5,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Scale",
            field="range",
            hide=True,
            id="controlnet_scale",
        )
        controlnet_start: float = Field(
            0.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Start",
            field="range",
            hide=True,
            id="controlnet_start",
        )
        controlnet_end: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet End",
            field="range",
            hide=True,
            id="controlnet_end",
        )
        canny_low_threshold: float = Field(
            0.31,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny Low Threshold",
            field="range",
            hide=True,
            id="canny_low_threshold",
        )
        canny_high_threshold: float = Field(
            0.125,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny High Threshold",
            field="range",
            hide=True,
            id="canny_high_threshold",
        )
        debug_canny: bool = Field(
            False,
            title="Debug Canny",
            field="checkbox",
            hide=True,
            id="debug_canny",
        )

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        controlnet_canny = ControlNetModel.from_pretrained(
            controlnet_model, torch_dtype=torch_dtype
        )
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
        )

        if args.safety_checker:
            self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                model_id, controlnet=controlnet_canny, vae=vae, torch_dtype=torch_dtype
            )
        else:
            self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                model_id,
                safety_checker=None,
                controlnet=controlnet_canny,
                vae=vae,
                torch_dtype=torch_dtype,
            )

        self.pipe.scheduler = LCMScheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            timestep_spacing="trailing",
        )
        self.pipe.load_lora_weights("jasperai/flash-sdxl")
        self.pipe.fuse_lora()

        self.canny_torch = SobelOperator(device=device)

        if args.sfast:
            from sfast.compilers.stable_diffusion_pipeline_compiler import (
                compile,
                CompilationConfig,
            )

            config = CompilationConfig.Default()
            # config.enable_xformers = True
            config.enable_triton = True
            config.enable_cuda_graph = True
            self.pipe = compile(self.pipe, config=config)

        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(device=device)
        if device.type != "mps":
            self.pipe.unet.to(memory_format=torch.channels_last)

        if args.compel:
            self.pipe.compel_proc = Compel(
                tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
                text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                requires_pooled=[False, True],
            )

        if args.torch_compile:
            self.pipe.unet = torch.compile(
                self.pipe.unet, mode="reduce-overhead", fullgraph=True
            )
            self.pipe.vae = torch.compile(
                self.pipe.vae, mode="reduce-overhead", fullgraph=True
            )
            self.pipe(
                prompt="warmup",
                image=[Image.new("RGB", (768, 768))],
                control_image=[Image.new("RGB", (768, 768))],
            )

    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        generator = torch.manual_seed(params.seed)

        prompt = params.prompt
        negative_prompt = params.negative_prompt
        prompt_embeds = None
        pooled_prompt_embeds = None
        negative_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        if hasattr(self.pipe, "compel_proc"):
            # Encode prompt and negative prompt in one batch, then split;
            # take the negative slices before narrowing the positive ones.
            _prompt_embeds, _pooled_prompt_embeds = self.pipe.compel_proc(
                [params.prompt, params.negative_prompt]
            )
            prompt = None
            negative_prompt = None
            prompt_embeds = _prompt_embeds[0:1]
            negative_prompt_embeds = _prompt_embeds[1:2]
            pooled_prompt_embeds = _pooled_prompt_embeds[0:1]
            negative_pooled_prompt_embeds = _pooled_prompt_embeds[1:2]

        control_image = self.canny_torch(
            params.image, params.canny_low_threshold, params.canny_high_threshold
        )
        steps = params.steps
        strength = params.strength
        if int(steps * strength) < 1:
            steps = math.ceil(1 / max(0.10, strength))

        results = self.pipe(
            image=params.image,
            control_image=control_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            generator=generator,
            strength=strength,
            num_inference_steps=steps,
            guidance_scale=0,
            width=params.width,
            height=params.height,
            output_type="pil",
            controlnet_conditioning_scale=params.controlnet_scale,
            control_guidance_start=params.controlnet_start,
            control_guidance_end=params.controlnet_end,
        )

        nsfw_content_detected = (
            results.nsfw_content_detected[0]
            if "nsfw_content_detected" in results
            else False
        )
        if nsfw_content_detected:
            return None
        result_image = results.images[0]
        if params.debug_canny:
            # Paste a small copy of the canny map into the bottom-right corner.
            w0, h0 = (200, 200)
            control_image = control_image.resize((w0, h0))
            w1, h1 = result_image.size
            result_image.paste(control_image, (w1 - w0, h1 - h0))

        return result_image
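
All three pipelines share the same guard on the step count: for img2img, diffusers executes roughly int(num_inference_steps * strength) denoising steps, so with the default steps=2 any strength below 0.5 would leave zero steps. The guard bumps steps to ceil(1 / max(0.10, strength)) so at least one step survives; since the strength field declares min=0.25, the 0.10 floor is only a safety net. A small check of the arithmetic:

import math

def effective_steps(steps: int, strength: float) -> int:
    # Mirrors the guard in predict(): diffusers' img2img executes about
    # int(steps * strength) denoising steps.
    if int(steps * strength) < 1:
        steps = math.ceil(1 / max(0.10, strength))
    return int(steps * strength)

assert effective_steps(2, 0.5) == 1   # 2 * 0.5 already yields one step
assert effective_steps(2, 0.3) == 1   # bumped to ceil(1/0.3) = 4 steps; 4 * 0.3 -> 1
assert effective_steps(2, 0.25) == 1  # bumped to ceil(1/0.25) = 4 steps; 4 * 0.25 -> 1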