radames committed on
Commit
ea6542c
·
1 Parent(s): decd923

HyperSDXL with MistoLine

Browse files
server/pipelines/controlnetMistoLineHyperSDXL.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ StableDiffusionXLControlNetImg2ImgPipeline,
3
+ ControlNetModel,
4
+ AutoencoderKL,
5
+ TCDScheduler,
6
+ )
7
+ from compel import Compel, ReturnedEmbeddingsType
8
+ import torch
9
+ from controlnet_aux import AnylineDetector
10
+ from huggingface_hub import hf_hub_download
11
+
12
# Optional Intel accelerator support: the import is best-effort and the module
# must keep loading when the extension is absent or fails to initialise.
# `except Exception` (instead of a bare `except:`) keeps that behaviour while
# letting KeyboardInterrupt / SystemExit propagate.
try:
    import intel_extension_for_pytorch as ipex  # type: ignore
except Exception:
    pass
16
+
17
+ import psutil
18
+ from config import Args
19
+ from pydantic import BaseModel, Field
20
+ from PIL import Image
21
+ import math
22
+
23
# Hugging Face repo ids used by the Pipeline class.
# Previous ControlNet checkpoint, kept for reference:
# controlnet_model = "diffusers/controlnet-canny-sdxl-1.0"
controlnet_model = "TheMistoAI/MistoLine"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# NOTE(review): taesd_model is not referenced by the Pipeline class in this
# file (the VAE is loaded from a hard-coded repo id there) — confirm whether
# it is still needed.
taesd_model = "madebyollin/taesdxl"

# Default values for the `prompt` / `negative_prompt` input fields.
default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
# HTML snippet exposed through Pipeline.Info.page_content.
page_content = """
<h1 class="text-3xl font-bold">Hyper-SDXL Unified + MistoLine</h1>
<h3 class="text-xl font-bold">Image-to-Image ControlNet</h3>

"""
35
+
36
+
37
class Pipeline:
    """Image-to-image SDXL pipeline driven by a MistoLine ControlNet.

    ``__init__`` loads the SDXL base model with the MistoLine ControlNet, fuses
    the Hyper-SD 1-step LoRA, switches to ``TCDScheduler`` and loads the
    Anyline line detector. ``predict`` extracts a line map from the input
    image and runs one img2img generation conditioned on it.
    """

    class Info(BaseModel):
        """Static metadata describing this pipeline."""

        # NOTE(review): name/title/description still read "SDXL Turbo" /
        # "from a text prompt" although this pipeline loads the Hyper-SD LoRA
        # with MistoLine and takes an image — confirm whether any caller keys
        # on `name` before renaming.
        name: str = "controlnet+SDXL+Turbo"
        title: str = "SDXL Turbo + Controlnet"
        description: str = "Generates an image from a text prompt"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        """Per-request generation parameters.

        The extra ``Field`` keyword arguments (``field``, ``hide``, ``id``,
        ``min``, ``max``, ``step``, ...) are passed through as field metadata;
        presumably they drive the UI widgets — confirm against the frontend.
        """

        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        negative_prompt: str = Field(
            default_negative_prompt,
            title="Negative Prompt",
            field="textarea",
            id="negative_prompt",
            hide=True,
        )
        seed: int = Field(
            2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
        )
        steps: int = Field(
            2, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
        )
        # NOTE(review): min=2 / max=15 do not bracket the 1024 default for
        # width/height; they look copied from `steps`. The fields are
        # disabled, so the values are left untouched here — confirm and fix.
        width: int = Field(
            1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
        )
        guidance_scale: float = Field(
            0.0,
            min=0,
            max=10,
            step=0.001,
            title="Guidance Scale",
            field="range",
            hide=True,
            id="guidance_scale",
        )
        strength: float = Field(
            0.5,
            min=0.25,
            max=1.0,
            step=0.001,
            title="Strength",
            field="range",
            hide=True,
            id="strength",
        )
        eta: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Eta",
            field="range",
            hide=True,
            id="eta",
        )
        controlnet_scale: float = Field(
            0.5,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Scale",
            field="range",
            hide=True,
            id="controlnet_scale",
        )
        controlnet_start: float = Field(
            0.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet Start",
            field="range",
            hide=True,
            id="controlnet_start",
        )
        controlnet_end: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Controlnet End",
            field="range",
            hide=True,
            id="controlnet_end",
        )
        # The "guassian" spelling is kept on purpose: it is both the public
        # field id and the keyword passed to the Anyline detector in predict().
        guassian_sigma: float = Field(
            2.0,
            min=0.01,
            max=10.0,
            step=0.001,
            title="(Anyline) Gaussian Sigma",
            field="range",
            hide=True,
            id="guassian_sigma",
        )
        intensity_threshold: float = Field(
            3,
            min=0,
            max=255,
            step=1,
            title="(Anyline) Intensity Threshold",
            field="range",
            hide=True,
            id="intensity_threshold",
        )
        debug_canny: bool = Field(
            False,
            title="Debug Canny",
            field="checkbox",
            hide=True,
            id="debug_canny",
        )

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        """Load all models and prepare the diffusers pipeline.

        Args:
            args: Server configuration; reads ``safety_checker``, ``sfast``,
                ``compel`` and ``torch_compile``.
            device: Device the pipeline and the line detector are moved to.
            torch_dtype: Dtype used when loading the ControlNet, VAE and
                base pipeline.
        """
        controlnet_canny = ControlNetModel.from_pretrained(
            controlnet_model,
            torch_dtype=torch_dtype,
            variant="fp16",
        )
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
        )

        # Same two call shapes as before: `safety_checker=None` is only passed
        # when the safety checker is disabled in the config.
        pipe_kwargs = dict(controlnet=controlnet_canny, vae=vae, torch_dtype=torch_dtype)
        if not args.safety_checker:
            pipe_kwargs["safety_checker"] = None
        self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            model_id, **pipe_kwargs
        )

        self.pipe.load_lora_weights(
            hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-1step-lora.safetensors")
        )

        self.pipe.scheduler = TCDScheduler.from_config(self.pipe.scheduler.config)

        self.pipe.fuse_lora()
        self.anyline = AnylineDetector.from_pretrained(
            "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
        ).to(device)

        if args.sfast:
            # Imported lazily so stable-fast is only required when enabled.
            from sfast.compilers.stable_diffusion_pipeline_compiler import (
                compile,
                CompilationConfig,
            )

            config = CompilationConfig.Default()
            # config.enable_xformers = True
            config.enable_triton = True
            config.enable_cuda_graph = True
            self.pipe = compile(self.pipe, config=config)

        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(device=device)
        if device.type != "mps":
            self.pipe.unet.to(memory_format=torch.channels_last)

        if args.compel:
            # predict() detects compel support via this attribute.
            self.pipe.compel_proc = Compel(
                tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
                text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                requires_pooled=[False, True],
            )

        if args.torch_compile:
            self.pipe.unet = torch.compile(
                self.pipe.unet, mode="reduce-overhead", fullgraph=True
            )
            self.pipe.vae = torch.compile(
                self.pipe.vae, mode="reduce-overhead", fullgraph=True
            )
            # One throw-away call right after compiling (warm-up).
            self.pipe(
                prompt="warmup",
                image=[Image.new("RGB", (768, 768))],
                control_image=[Image.new("RGB", (768, 768))],
            )

    def predict(self, params: "Pipeline.InputParams") -> "Image.Image | None":
        """Run one ControlNet img2img generation.

        Args:
            params: Request parameters. Besides the declared ``InputParams``
                fields this reads ``params.image``, which is not declared on
                ``InputParams`` in this file — presumably attached by the
                caller; verify.

        Returns:
            The generated PIL image, or ``None`` when the pipeline output
            reports NSFW content for the first image.
        """
        generator = torch.manual_seed(params.seed)

        prompt = params.prompt
        negative_prompt = params.negative_prompt
        prompt_embeds = None
        pooled_prompt_embeds = None
        negative_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        if hasattr(self.pipe, "compel_proc"):
            # Row 0 holds the positive prompt, row 1 the negative prompt.
            # The batched outputs stay in their own locals so that both rows
            # can be sliced from the full tensors; previously the pooled
            # tensor was overwritten with its [0:1] slice first, which made
            # the negative pooled embedding ([1:2] of a one-row tensor) empty.
            _prompt_embeds, _pooled_prompt_embeds = self.pipe.compel_proc(
                [params.prompt, params.negative_prompt]
            )
            prompt = None
            negative_prompt = None
            prompt_embeds = _prompt_embeds[0:1]
            pooled_prompt_embeds = _pooled_prompt_embeds[0:1]
            negative_prompt_embeds = _prompt_embeds[1:2]
            negative_pooled_prompt_embeds = _pooled_prompt_embeds[1:2]

        control_image = self.anyline(
            params.image,
            detect_resolution=1280,
            guassian_sigma=max(0.01, params.guassian_sigma),
            intensity_threshold=params.intensity_threshold,
        )

        steps = params.steps
        strength = params.strength
        # Raise the step count when steps * strength would truncate to zero.
        if int(steps * strength) < 1:
            steps = math.ceil(1 / max(0.10, strength))

        results = self.pipe(
            image=params.image,
            control_image=control_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            generator=generator,
            strength=strength,
            eta=params.eta,
            num_inference_steps=steps,
            guidance_scale=params.guidance_scale,
            width=params.width,
            height=params.height,
            output_type="pil",
            controlnet_conditioning_scale=params.controlnet_scale,
            control_guidance_start=params.controlnet_start,
            control_guidance_end=params.controlnet_end,
        )

        nsfw_content_detected = (
            results.nsfw_content_detected[0]
            if "nsfw_content_detected" in results
            else False
        )
        if nsfw_content_detected:
            return None
        result_image = results.images[0]
        if params.debug_canny:
            # paste control_image on top of result_image (bottom-right corner)
            w0, h0 = (200, 200)
            control_image = control_image.resize((w0, h0))
            w1, h1 = result_image.size
            result_image.paste(control_image, (w1 - w0, h1 - h0))

        return result_image
server/requirements.txt CHANGED
@@ -17,4 +17,5 @@ oneflow @ https://github.com/siliconflow/oneflow_releases/releases/download/comm
17
  onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
18
  setuptools
19
  mpmath==1.3.0
20
- numpy==1.*
 
 
17
  onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
18
  setuptools
19
  mpmath==1.3.0
20
+ numpy==1.*
21
+ controlnet-aux