---
license: apache-2.0
language:
- zh
- en
- fr
- de
- ja
- kg
base_model:
- stabilityai/stable-diffusion-xl-base-1.0
pipeline_tag: text-to-image
---

![PEA-Diffusion](./PEA-Diffusion.png)

Text-to-image diffusion models are well known for their ability to generate realistic images from textual prompts. However, existing work has focused predominantly on English, leaving non-English text-to-image generation poorly supported. Common translation-based workarounds cannot handle culture-specific concepts, while training from scratch on a language-specific dataset is prohibitively expensive. In our paper we propose a simple plug-and-play language transfer method based on knowledge distillation: we train a lightweight, MLP-like parameter-efficient adapter (PEA) with only 6M parameters under teacher knowledge distillation, using a small parallel corpus. Surprisingly, even with the UNet parameters frozen, the adapter achieves remarkable performance on a language-specific prompt evaluation set, showing that PEA can unlock the latent generation ability of the original UNet, and it closely approaches the English text-to-image model on a general prompt evaluation set. The adapter can also be used as a plugin to obtain strong results on downstream cross-lingual text-to-image tasks.
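
As a rough illustration of the recipe described above (not the released training code), the sketch below shows what an adapter-only, feature-level distillation step could look like: only the small MLP adapter receives gradients, while the multilingual encoder, the teacher's English text branch, and the UNet all stay frozen. The function name, `teacher_text_features`, and the batch layout are hypothetical placeholders, and the paper's actual objective may differ; treat this purely as an aid to intuition.

```python
import torch
import torch.nn.functional as F

def pea_distillation_step(adapter, mul_encoder, teacher_text_features, optimizer, batch):
    """One hypothetical adapter-only distillation step on a parallel prompt pair.

    batch["src"]: tokenized prompt in the target language
    batch["en"]:  tokenized English translation of the same prompt
    teacher_text_features: placeholder for the frozen English text branch the UNet was trained on
    """
    with torch.no_grad():
        _, student_tokens = mul_encoder.encode_text(batch["src"])              # frozen multilingual encoder
        teacher_pooled, teacher_tokens = teacher_text_features(batch["en"])    # frozen teacher features

    # only the ~6M adapter parameters are updated
    pooled_pred, tokens_pred = adapter(student_tokens)

    loss = F.mse_loss(tokens_pred, teacher_tokens) + F.mse_loss(pooled_pred, teacher_pooled)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```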

# Usage
We provide adapter examples for [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [Playground v2.5](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), and [Stable Cascade](https://huggingface.co/stabilityai/stable-cascade). For SD3, please refer to https://huggingface.co/OPPOer/MultilingualSD3-adapter, and for FLUX.1, please refer to https://huggingface.co/OPPOer/MultilingualFLUX.1-adapter.

## `SDXL`
We use the multilingual encoder [Mul-OpenCLIP](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k).
As described in the paper, you can swap in any SDXL-derived model here, including sampling-accelerated variants, and the adapter works without further changes.

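The script below loads local checkpoint files for the multilingual encoder and the adapter. One way to obtain them (an assumption about your environment, not part of the original script) is via `huggingface_hub`, using the repo ids and filenames that appear in the script:

```python
# Optional helper: download the weights referenced by the script below.
from huggingface_hub import hf_hub_download

text_encoder_path = hf_hub_download(
    repo_id="laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k",
    filename="open_clip_pytorch_model.bin",
)
proj_path = hf_hub_download(repo_id="OPPOer/PEA-Diffusion", filename="pytorch_model.bin")
```
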
```python
import torch
import torch.nn as nn

from PIL import Image
from diffusers import AutoencoderKL, StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import open_clip


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x2 = self.act_fn(x)
        x2 = self.fc3(x2)
        if self.use_residual:
            x = x + residual
        # x1: pooled embedding for SDXL's `text_embeds`; x2: token-level embedding for cross-attention
        x1 = torch.mean(x, 1)
        return x1, x2


class StableDiffusionTest():

    def __init__(self, model_id, text_encoder_path, proj_path):
        super().__init__()
        self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
        self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
        self.text_encoder.text.output_tokens = True
        self.proj = MLP(1024, 1280, 1024, 2048, use_residual=False).to(device, dtype=dtype)
        self.text_encoder = self.text_encoder.to(device, dtype=dtype)

        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
        scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=dtype).to(device)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)
        self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))

    def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_input_ids = self.tokenizer(prompt).to(device)
        _, text_embeddings = self.text_encoder.encode_text(text_input_ids)

        add_text_embeds, text_embeddings_2048 = self.proj(text_embeddings)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input_ids = self.tokenizer(uncond_tokens).to(device)
            _, uncond_embeddings = self.text_encoder.encode_text(uncond_input_ids)

            add_text_embeds_uncond, uncond_embeddings_2048 = self.proj(uncond_embeddings)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings_2048.shape[1]
            uncond_embeddings_2048 = uncond_embeddings_2048.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings_2048 = uncond_embeddings_2048.view(batch_size * num_images_per_prompt, seq_len, -1)

            text_embeddings_2048 = torch.cat([uncond_embeddings_2048, text_embeddings_2048])
            add_text_embeds = torch.cat([add_text_embeds_uncond, add_text_embeds])

        return text_embeddings_2048, add_text_embeds

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # 0. Default height and width to unet
        height = height or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        width = width or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.pipe._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 2. Encode input prompt
        prompt_embeds, add_text_embeds = self.encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt)

        # 3. Prepare timesteps
        self.pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.pipe.scheduler.timesteps

        # 4. Prepare latent variables
        num_channels_latents = self.pipe.unet.config.in_channels
        latents = self.pipe.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare extra step kwargs and SDXL added conditioning
        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)

        add_time_ids = self._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype)
        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        # 6. Denoising loop
        for i, t in enumerate(self.pipe.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 7. Decode in float32 for numerical stability
        self.vae.to(dtype=torch.float32)

        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0),
        )
        # if xformers or torch_2_0 is used, the attention block does not need
        # to be in float32, which can save lots of memory
        if not use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(latents.dtype)
            self.vae.decoder.conv_in.to(latents.dtype)
            self.vae.decoder.mid_block.to(latents.dtype)
        else:
            latents = latents.float()

        # 8. Post-processing
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type="np")

        # 9. Convert to PIL
        if output_type == "pil":
            image = self.pipe.numpy_to_pil(image)

        return image


if __name__ == '__main__':
    device = "cuda"
    dtype = torch.float16

    text_encoder_path = 'laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/open_clip_pytorch_model.bin'
    model_id = "stablediffusionapi/protovision-xl-v6.6"
    proj_path = "OPPOer/PEA-Diffusion/pytorch_model.bin"

    sdt = StableDiffusionTest(model_id, text_encoder_path, proj_path)

    batch = 2
    height = 1024
    width = 1024
    while True:
        raw_text = input("\nPlease Input Query (stop to exit) >>> ")
        if not raw_text:
            print('Query should not be empty!')
            continue
        if raw_text == "stop":
            break
        images = sdt([raw_text] * batch, height=height, width=width)
        grid = image_grid(images, rows=1, cols=batch)
        grid.save("SDXL.png")
```
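
Because the adapter only touches the text branch, any SDXL-derived checkpoint can be dropped in by changing `model_id`. A minimal sketch reusing the class and paths defined above (the checkpoint choice and the Chinese prompt are just examples, not recommendations from the original authors):

```python
# Reuse StableDiffusionTest with a different SDXL-derived checkpoint; encoder and adapter stay the same.
sdt = StableDiffusionTest("stabilityai/stable-diffusion-xl-base-1.0", text_encoder_path, proj_path)
images = sdt(["一只戴着红色围巾的柴犬，水彩风格"] * 2, height=1024, width=1024)
image_grid(images, rows=1, cols=2).save("SDXL_custom.png")
```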


## `Playground v2.5`
We use the multilingual encoder [Mul-OpenCLIP](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k).

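As with the SDXL example, the script below expects local weight files; Playground v2.5 uses its own adapter checkpoint (`pytorch_model_pg.bin` in the script). A possible way to fetch it, assuming the filename from the script is current:

```python
# Optional helper: fetch the Playground-specific adapter weights used below.
from huggingface_hub import hf_hub_download

proj_path = hf_hub_download(repo_id="OPPOer/PEA-Diffusion", filename="pytorch_model_pg.bin")
```
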
```python
import os, sys
import random
import argparse
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm import tqdm
import numpy as np
from PIL import Image

import torch
import torch.nn as nn

from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
import open_clip


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


class MLP(nn.Module):
    def __init__(self, in_dim=1024, out_dim=1280, hidden_dim=2048, out_dim1=2048, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )
        self.fc = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.projector(x)
        x2 = nn.GELU()(x)
        x2 = self.fc(x2)
        if self.use_residual:
            x = x + residual
        # x1: pooled embedding for `text_embeds`; x2: token-level embedding for cross-attention
        x1 = torch.mean(x, 1)
        return x1, x2


class StableDiffusionTest():
    def __init__(self, model_id, text_encoder_path, proj_path):
        super().__init__()
        self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
        self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
        self.text_encoder.text.output_tokens = True
        self.text_encoder = self.text_encoder.to(device, dtype=dtype)
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)

        self.pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, variant="fp16").to(device)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)

        self.proj = MLP(1024, 1280, 2048, 2048, use_residual=False).to(device, dtype=dtype)
        self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))

    def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        text_input_ids = self.tokenizer(prompt).to(device)
        _, text_embeddings = self.text_encoder.encode_text(text_input_ids)
        add_text_embeds, text_embeddings_2048 = self.proj(text_embeddings)

        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input_ids = self.tokenizer(uncond_tokens).to(device)
            _, uncond_embeddings = self.text_encoder.encode_text(uncond_input_ids)
            add_text_embeds_uncond, uncond_embeddings_2048 = self.proj(uncond_embeddings)

            seq_len = uncond_embeddings_2048.shape[1]
            uncond_embeddings_2048 = uncond_embeddings_2048.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings_2048 = uncond_embeddings_2048.view(batch_size * num_images_per_prompt, seq_len, -1)

            text_embeddings_2048 = torch.cat([uncond_embeddings_2048, text_embeddings_2048])
            add_text_embeds = torch.cat([add_text_embeds_uncond, add_text_embeds])

        return text_embeddings_2048, add_text_embeds

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        height = height or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        width = width or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.pipe._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds, add_text_embeds = self.encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt)

        self.pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.pipe.scheduler.timesteps
        num_channels_latents = self.pipe.unet.config.in_channels
        latents = self.pipe.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)

        add_time_ids = self._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype)
        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        for i, t in enumerate(self.pipe.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # decode in float32 for numerical stability
        self.vae.to(dtype=torch.float32)

        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0),
        )
        if not use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(latents.dtype)
            self.vae.decoder.conv_in.to(latents.dtype)
            self.vae.decoder.mid_block.to(latents.dtype)
        else:
            latents = latents.float()

        # Playground v2.5 uses a VAE with shifted latent statistics
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type="np")

        if output_type == "pil":
            image = self.pipe.numpy_to_pil(image)

        return image


if __name__ == '__main__':
    device = "cuda"
    dtype = torch.float16

    model_id = "playgroundai/playground-v2.5-1024px-aesthetic"
    text_encoder_path = 'laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/open_clip_pytorch_model.bin'
    proj_path = "OPPOer/PEA-Diffusion/pytorch_model_pg.bin"

    sdt = StableDiffusionTest(model_id, text_encoder_path, proj_path)

    batch = 2
    height = 1024
    width = 1024

    while True:
        raw_text = input("\nPlease Input Query (stop to exit) >>> ")
        if not raw_text:
            print('Query should not be empty!')
            continue
        if raw_text == "stop":
            break
        images = sdt([raw_text] * batch, height=height, width=width)
        grid = image_grid(images, rows=1, cols=batch)
        grid.save("PG.png")
```
To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers) documentation.

# License
The adapter itself is released under the Apache License 2.0, but its use must also comply with the license of the underlying base model.

# Citation
```
@misc{ma2023peadiffusion,
      title={PEA-Diffusion: Parameter-Efficient Adapter with Knowledge Distillation in non-English Text-to-Image Generation},
      author={Jian Ma and Chen Chen and Qingsong Xie and Haonan Lu},
      year={2023},
      eprint={2311.17086},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```