suvadityamuk committed on
Commit
202b06d
·
1 Parent(s): 73472c7

Upload stable_diff_comp_2.py

Files changed (1)
  1. stable_diff_comp_2.py +400 -0
stable_diff_comp_2.py ADDED
@@ -0,0 +1,400 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

pipe1_model_id = "CompVis/stable-diffusion-v1-1"
pipe2_model_id = "CompVis/stable-diffusion-v1-2"
pipe3_model_id = "CompVis/stable-diffusion-v1-3"
pipe4_model_id = "CompVis/stable-diffusion-v1-4"

class StableDiffusionComparisonPipeline(DiffusionPipeline):
    r"""
    Pipeline for parallel comparison of Stable Diffusion v1.1-v1.4.

    This pipeline inherits from `DiffusionPipeline` and depends on an auth token for
    downloading pre-trained checkpoints from the Hugging Face Hub.

    Args:
        pipe1 (`StableDiffusionPipeline` or `str`, *optional*):
            A Stable Diffusion pipeline prepared from the SD v1.1 checkpoint on the Hugging Face Hub.
        pipe2 (`StableDiffusionPipeline` or `str`, *optional*):
            A Stable Diffusion pipeline prepared from the SD v1.2 checkpoint on the Hugging Face Hub.
        pipe3 (`StableDiffusionPipeline` or `str`, *optional*):
            A Stable Diffusion pipeline prepared from the SD v1.3 checkpoint on the Hugging Face Hub.
        pipe4 (`StableDiffusionPipeline` or `str`, *optional*):
            A Stable Diffusion pipeline prepared from the SD v1.4 checkpoint on the Hugging Face Hub.
    """

    # def __init__(
    #     self,
    #     sd1_1: Union[StableDiffusionPipeline, str],
    #     sd1_2: Union[StableDiffusionPipeline, str],
    #     sd1_3: Union[StableDiffusionPipeline, str],
    #     sd1_4: Union[StableDiffusionPipeline, str],
    # ):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # pipe1-pipe3 are downloaded from the Hub model ids above; pipe4 is assembled
        # from the components passed to this constructor.
        self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
        self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
        self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
        self.pipe4 = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)

        # if not isinstance(sd1_1, StableDiffusionPipeline):
        #     self.pipe1 = StableDiffusionPipeline.from_pretrained(
        #         pipe1_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
        #     )
        # else:
        #     self.pipe1 = sd1_1
        # if not isinstance(sd1_2, StableDiffusionPipeline):
        #     self.pipe2 = StableDiffusionPipeline.from_pretrained(
        #         pipe2_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
        #     )
        # else:
        #     self.pipe2 = sd1_2
        # if not isinstance(sd1_3, StableDiffusionPipeline):
        #     self.pipe3 = StableDiffusionPipeline.from_pretrained(
        #         pipe3_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
        #     )
        # else:
        #     self.pipe3 = sd1_3
        # if not isinstance(sd1_4, StableDiffusionPipeline):
        #     self.pipe4 = StableDiffusionPipeline.from_pretrained(
        #         pipe4_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
        #     )
        # else:
        #     self.pipe4 = sd1_4

    @property
    def layers(self) -> Dict[str, Any]:
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    @torch.no_grad()
    def text2img_sd1_1(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_2(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_3(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_4(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation. It produces four results by running the
        four pipelines for SD v1.1-v1.4 on the same inputs, one after the other.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
                element is a list of `bool`s denoting whether the corresponding generated image likely represents
                "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)

        # Check that the height and width are divisible by 8
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")

        # Get the result from the Stable Diffusion checkpoint v1.1
        res1 = self.text2img_sd1_1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from the Stable Diffusion checkpoint v1.2
        res2 = self.text2img_sd1_2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from the Stable Diffusion checkpoint v1.3
        res3 = self.text2img_sd1_3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from the Stable Diffusion checkpoint v1.4
        res4 = self.text2img_sd1_4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Collect the images from all four results and return them in a single StableDiffusionPipelineOutput
        return StableDiffusionPipelineOutput(
            images=[res1[0], res2[0], res3[0], res4[0]], nsfw_content_detected=None
        )
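
A minimal usage sketch for the pipeline above, assuming the file is importable as `stable_diff_comp_2` and the environment is authenticated with the Hugging Face Hub (the constructor downloads the v1.1-v1.3 checkpoints); the prompt and output file names are illustrative, not part of the upload.

# Usage sketch (assumes stable_diff_comp_2.py is on the Python path and Hub access is configured).
from diffusers import StableDiffusionPipeline

from stable_diff_comp_2 import StableDiffusionComparisonPipeline

# Reuse the SD v1.4 components for pipe4; the other three checkpoints are fetched in __init__.
base = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = StableDiffusionComparisonPipeline(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=base.unet,
    scheduler=base.scheduler,
    safety_checker=base.safety_checker,
    feature_extractor=base.feature_extractor,
)

result = pipe(prompt="an astronaut riding a horse", num_inference_steps=25)
# result.images holds one entry per checkpoint (v1.1, v1.2, v1.3, v1.4), each a list of PIL images.
for version, images in zip(("1-1", "1-2", "1-3", "1-4"), result.images):
    images[0].save(f"sd-v{version}.png")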